This is an automated email from the ASF dual-hosted git repository.
mbutrovich pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 79b83d873 chore: Refactor JVM shuffle: Move `SpillSorter` to top level
class and add tests (#3081)
79b83d873 is described below
commit 79b83d8733d976310552517f042c498c63b97592
Author: Andy Grove <[email protected]>
AuthorDate: Wed Jan 14 13:03:10 2026 -0700
chore: Refactor JVM shuffle: Move `SpillSorter` to top level class and add
tests (#3081)
---
.github/workflows/pr_build_linux.yml | 1 +
.github/workflows/pr_build_macos.yml | 1 +
dev/ensure-jars-have-correct-contents.sh | 1 +
.../shuffle/sort/CometShuffleExternalSorter.java | 268 ++--------------
.../org/apache/spark/shuffle/sort/SpillSorter.java | 352 +++++++++++++++++++++
.../spark/shuffle/sort/SpillSorterSuite.scala | 262 +++++++++++++++
6 files changed, 638 insertions(+), 247 deletions(-)
diff --git a/.github/workflows/pr_build_linux.yml
b/.github/workflows/pr_build_linux.yml
index 9f5324b26..8e4dc5124 100644
--- a/.github/workflows/pr_build_linux.yml
+++ b/.github/workflows/pr_build_linux.yml
@@ -122,6 +122,7 @@ jobs:
org.apache.comet.exec.CometAsyncShuffleSuite
org.apache.comet.exec.DisableAQECometShuffleSuite
org.apache.comet.exec.DisableAQECometAsyncShuffleSuite
+ org.apache.spark.shuffle.sort.SpillSorterSuite
- name: "parquet"
value: |
org.apache.comet.parquet.CometParquetWriterSuite
diff --git a/.github/workflows/pr_build_macos.yml
b/.github/workflows/pr_build_macos.yml
index 58ba48134..f94071dbc 100644
--- a/.github/workflows/pr_build_macos.yml
+++ b/.github/workflows/pr_build_macos.yml
@@ -85,6 +85,7 @@ jobs:
org.apache.comet.exec.CometAsyncShuffleSuite
org.apache.comet.exec.DisableAQECometShuffleSuite
org.apache.comet.exec.DisableAQECometAsyncShuffleSuite
+ org.apache.spark.shuffle.sort.SpillSorterSuite
- name: "parquet"
value: |
org.apache.comet.parquet.CometParquetWriterSuite
diff --git a/dev/ensure-jars-have-correct-contents.sh
b/dev/ensure-jars-have-correct-contents.sh
index f698fe78f..570aeabb2 100755
--- a/dev/ensure-jars-have-correct-contents.sh
+++ b/dev/ensure-jars-have-correct-contents.sh
@@ -86,6 +86,7 @@ allowed_expr+="|^org/apache/spark/shuffle/$"
allowed_expr+="|^org/apache/spark/shuffle/sort/$"
allowed_expr+="|^org/apache/spark/shuffle/sort/CometShuffleExternalSorter.*$"
allowed_expr+="|^org/apache/spark/shuffle/sort/RowPartition.class$"
+allowed_expr+="|^org/apache/spark/shuffle/sort/SpillSorter.*$"
allowed_expr+="|^org/apache/spark/shuffle/comet/.*$"
allowed_expr+="|^org/apache/spark/sql/$"
# allow ExplainPlanGenerator trait since it may not be available in older
Spark versions
diff --git
a/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java
b/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java
index 8bc22b342..b026c6bc4 100644
---
a/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java
+++
b/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java
@@ -23,7 +23,6 @@ import java.io.File;
import java.io.IOException;
import java.util.LinkedList;
import java.util.concurrent.*;
-import javax.annotation.Nullable;
import scala.Tuple2;
@@ -32,7 +31,6 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
-import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.memory.SparkOutOfMemoryError;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.shuffle.comet.CometShuffleChecksumSupport;
@@ -41,17 +39,14 @@ import org.apache.spark.shuffle.comet.TooLargePageException;
import org.apache.spark.sql.comet.execution.shuffle.CometUnsafeShuffleWriter;
import org.apache.spark.sql.comet.execution.shuffle.ShuffleThreadPool;
import org.apache.spark.sql.comet.execution.shuffle.SpillInfo;
-import org.apache.spark.sql.comet.execution.shuffle.SpillWriter;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.TempShuffleBlockId;
-import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.UnsafeAlignedOffset;
import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.util.Utils;
import org.apache.comet.CometConf$;
-import org.apache.comet.Native;
/**
* An external sorter that is specialized for sort-based shuffle.
@@ -169,10 +164,28 @@ public final class CometShuffleExternalSorter implements
CometShuffleChecksumSup
this.threadPool = null;
}
- this.activeSpillSorter = new SpillSorter();
-
this.preferDictionaryRatio =
(double)
CometConf$.MODULE$.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO().get();
+
+ this.activeSpillSorter = createSpillSorter();
+ }
+
+ /** Creates a SpillSorter sharing this sorter's checksums, metrics and spills; spills via spill(). */
+ private SpillSorter createSpillSorter() {
+ return new SpillSorter(
+ allocator,
+ initialSize,
+ schema,
+ uaoSize,
+ preferDictionaryRatio,
+ compressionCodec,
+ compressionLevel,
+ checksumAlgorithm,
+ partitionChecksums,
+ writeMetrics,
+ taskContext,
+ spills,
+ this::spill);
}
public long[] getChecksums() {
@@ -237,7 +250,7 @@ public final class CometShuffleExternalSorter implements
CometShuffleChecksumSup
}
}
- activeSpillSorter = new SpillSorter();
+ activeSpillSorter = createSpillSorter();
} else {
activeSpillSorter.writeSortedFileNative(false, tracingEnabled);
final long spillSize = activeSpillSorter.freeMemory();
@@ -410,243 +423,4 @@ public final class CometShuffleExternalSorter implements
CometShuffleChecksumSup
return spills.toArray(new SpillInfo[spills.size()]);
}
-
- class SpillSorter extends SpillWriter {
- private boolean freed = false;
-
- private SpillInfo spillInfo;
-
- // These variables are reset after spilling:
- @Nullable private ShuffleInMemorySorter inMemSorter;
-
- // This external sorter can call native code to sort partition ids and
record pointers of rows.
- // In order to do that, we need pass the address of the internal array in
the sorter to native.
- // But we cannot access it as it is private member in the Spark sorter.
Instead, we allocate
- // the array and assign the pointer array in the sorter.
- private LongArray sorterArray;
-
- SpillSorter() {
- this.spillInfo = null;
-
- this.allocator = CometShuffleExternalSorter.this.allocator;
-
- // Allocate array for in-memory sorter.
- // As we cannot access the address of the internal array in the sorter,
so we need to
- // allocate the array manually and expand the pointer array in the
sorter.
- // We don't want in-memory sorter to allocate memory but the initial
size cannot be zero.
- try {
- this.inMemSorter = new ShuffleInMemorySorter(allocator, 1, true);
- } catch (java.lang.IllegalAccessError e) {
- throw new java.lang.RuntimeException(
- "Error loading in-memory sorter check class path -- see "
- +
"https://github.com/apache/arrow-datafusion-comet?tab=readme-ov-file#enable-comet-shuffle",
- e);
- }
- sorterArray = allocator.allocateArray(initialSize);
- this.inMemSorter.expandPointerArray(sorterArray);
-
- this.allocatedPages = new LinkedList<>();
-
- this.nativeLib = new Native();
- this.dataTypes = serializeSchema(schema);
- }
-
- /** Frees allocated memory pages of this writer */
- @Override
- public long freeMemory() {
- // We need to synchronize here because we may get the memory usage by
calling
- // this method in the task thread.
- synchronized (this) {
- return super.freeMemory();
- }
- }
-
- @Override
- public long getMemoryUsage() {
- // We need to synchronize here because we may free the memory pages in
another thread,
- // i.e. when spilling, but this method may be called in the task thread.
- synchronized (this) {
- long totalPageSize = super.getMemoryUsage();
-
- if (freed) {
- return totalPageSize;
- } else {
- return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) +
totalPageSize;
- }
- }
- }
-
- @Override
- protected void spill(int required) throws IOException {
- CometShuffleExternalSorter.this.spill();
- }
-
- /** Free the pointer array held by this sorter. */
- public void freeArray() {
- synchronized (this) {
- inMemSorter.free();
- freed = true;
- }
- }
-
- /**
- * Reset the in-memory sorter's pointer array only after freeing up the
memory pages holding the
- * records.
- */
- public void reset() {
- // We allocate pointer array outside the sorter.
- // So we can get array address which can be used by native code.
- inMemSorter.reset();
- sorterArray = allocator.allocateArray(initialSize);
- inMemSorter.expandPointerArray(sorterArray);
- }
-
- void setSpillInfo(SpillInfo spillInfo) {
- this.spillInfo = spillInfo;
- }
-
- public int numRecords() {
- return this.inMemSorter.numRecords();
- }
-
- public void writeSortedFileNative(boolean isLastFile, boolean
tracingEnabled)
- throws IOException {
- // This call performs the actual sort.
- long arrayAddr = this.sorterArray.getBaseOffset();
- int pos = inMemSorter.numRecords();
- nativeLib.sortRowPartitionsNative(arrayAddr, pos, tracingEnabled);
- ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords =
- new ShuffleInMemorySorter.ShuffleSorterIterator(pos,
this.sorterArray, 0);
-
- // If there are no sorted records, so we don't need to create an empty
spill file.
- if (!sortedRecords.hasNext()) {
- return;
- }
-
- final ShuffleWriteMetricsReporter writeMetricsToUse;
-
- if (isLastFile) {
- // We're writing the final non-spill file, so we _do_ want to count
this as shuffle bytes.
- writeMetricsToUse = writeMetrics;
- } else {
- // We're spilling, so bytes written should be counted towards spill
rather than write.
- // Create a dummy WriteMetrics object to absorb these metrics, since
we don't want to count
- // them towards shuffle bytes written.
- writeMetricsToUse = new ShuffleWriteMetrics();
- }
-
- int currentPartition = -1;
-
- final RowPartition rowPartition = new RowPartition(initialSize);
-
- while (sortedRecords.hasNext()) {
- sortedRecords.loadNext();
- final int partition =
sortedRecords.packedRecordPointer.getPartitionId();
- assert (partition >= currentPartition);
- if (partition != currentPartition) {
- // Switch to the new partition
- if (currentPartition != -1) {
-
- if (partitionChecksums.length > 0) {
- // If checksum is enabled, we need to update the checksum for
the current partition.
- setChecksum(partitionChecksums[currentPartition]);
- setChecksumAlgo(checksumAlgorithm);
- }
-
- long written =
- doSpilling(
- dataTypes,
- spillInfo.file,
- rowPartition,
- writeMetricsToUse,
- preferDictionaryRatio,
- compressionCodec,
- compressionLevel,
- tracingEnabled);
- spillInfo.partitionLengths[currentPartition] = written;
-
- // Store the checksum for the current partition.
- partitionChecksums[currentPartition] = getChecksum();
- }
- currentPartition = partition;
- }
-
- final long recordPointer =
sortedRecords.packedRecordPointer.getRecordPointer();
- final long recordOffsetInPage =
allocator.getOffsetInPage(recordPointer);
- // Note that we need to skip over record key (partition id)
- // Note that we already use off-heap memory for serialized rows, so
recordPage is always
- // null.
- int recordSizeInBytes = UnsafeAlignedOffset.getSize(null,
recordOffsetInPage) - 4;
- long recordReadPosition = recordOffsetInPage + uaoSize + 4; // skip
over record length too
- rowPartition.addRow(recordReadPosition, recordSizeInBytes);
- }
-
- if (currentPartition != -1) {
- long written =
- doSpilling(
- dataTypes,
- spillInfo.file,
- rowPartition,
- writeMetricsToUse,
- preferDictionaryRatio,
- compressionCodec,
- compressionLevel,
- tracingEnabled);
- spillInfo.partitionLengths[currentPartition] = written;
-
- synchronized (spills) {
- spills.add(spillInfo);
- }
- }
-
- if (!isLastFile) { // i.e. this is a spill file
- // The current semantics of `shuffleRecordsWritten` seem to be that
it's updated when
- // records
- // are written to disk, not when they enter the shuffle sorting code.
DiskBlockObjectWriter
- // relies on its `recordWritten()` method being called in order to
trigger periodic updates
- // to
- // `shuffleBytesWritten`. If we were to remove the `recordWritten()`
call and increment that
- // counter at a higher-level, then the in-progress metrics for records
written and bytes
- // written would get out of sync.
- //
- // When writing the last file, we pass `writeMetrics` directly to the
DiskBlockObjectWriter;
- // in all other cases, we pass in a dummy write metrics to capture
metrics, then copy those
- // metrics to the true write metrics here. The reason for performing
this copying is so that
- // we can avoid reporting spilled bytes as shuffle write bytes.
- //
- // Note that we intentionally ignore the value of
`writeMetricsToUse.shuffleWriteTime()`.
- // Consistent with ExternalSorter, we do not count this IO towards
shuffle write time.
- // SPARK-3577 tracks the spill time separately.
-
- // This is guaranteed to be a ShuffleWriteMetrics based on the if
check in the beginning
- // of this method.
- synchronized (writeMetrics) {
- writeMetrics.incRecordsWritten(
- ((ShuffleWriteMetrics) writeMetricsToUse).recordsWritten());
- taskContext
- .taskMetrics()
- .incDiskBytesSpilled(((ShuffleWriteMetrics)
writeMetricsToUse).bytesWritten());
- }
- }
- }
-
- public boolean hasSpaceForAnotherRecord() {
- return inMemSorter.hasSpaceForAnotherRecord();
- }
-
- public void expandPointerArray(LongArray newArray) {
- inMemSorter.expandPointerArray(newArray);
- this.sorterArray = newArray;
- }
-
- public void insertRecord(Object recordBase, long recordOffset, int length,
int partitionId) {
- final Object base = currentPage.getBaseObject();
- final long recordAddress =
allocator.encodePageNumberAndOffset(currentPage, pageCursor);
- UnsafeAlignedOffset.putSize(base, pageCursor, length);
- pageCursor += uaoSize;
- Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
- pageCursor += length;
- inMemSorter.insertRecord(recordAddress, partitionId);
- }
- }
}
diff --git a/spark/src/main/java/org/apache/spark/shuffle/sort/SpillSorter.java
b/spark/src/main/java/org/apache/spark/shuffle/sort/SpillSorter.java
new file mode 100644
index 000000000..36b50e620
--- /dev/null
+++ b/spark/src/main/java/org/apache/spark/shuffle/sort/SpillSorter.java
@@ -0,0 +1,357 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import java.io.IOException;
+import java.util.LinkedList;
+import javax.annotation.Nullable;
+
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
+import org.apache.spark.shuffle.comet.CometShuffleMemoryAllocatorTrait;
+import org.apache.spark.sql.comet.execution.shuffle.SpillInfo;
+import org.apache.spark.sql.comet.execution.shuffle.SpillWriter;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.UnsafeAlignedOffset;
+import org.apache.spark.unsafe.array.LongArray;
+
+import org.apache.comet.Native;
+
+/**
+ * A spill sorter that buffers records in memory, sorts them by partition ID, and writes them to
+ * disk. This class is used by CometShuffleExternalSorter to manage individual spill operations.
+ *
+ * <p>Each SpillSorter instance manages its own memory pages and pointer array. When spilling is
+ * triggered, the records are sorted by partition ID using native code and written to a spill file.
+ */
+public class SpillSorter extends SpillWriter {
+
+  /** Callback interface for triggering spill operations in the parent sorter. */
+  @FunctionalInterface
+  public interface SpillCallback {
+    void onSpillRequired() throws IOException;
+  }
+
+  // Configuration fields (immutable after construction)
+  private final int initialSize;
+  private final int uaoSize;
+  private final double preferDictionaryRatio;
+  private final String compressionCodec;
+  private final int compressionLevel;
+  private final String checksumAlgorithm;
+
+  // Shared state (mutable, passed by reference from parent).
+  // An empty partitionChecksums array means checksums are disabled and it is never indexed.
+  private final long[] partitionChecksums;
+  private final ShuffleWriteMetricsReporter writeMetrics;
+  private final TaskContext taskContext;
+  private final LinkedList<SpillInfo> spills;
+  private final SpillCallback spillCallback;
+
+  // Internal state
+  private boolean freed = false;
+  private SpillInfo spillInfo;
+  @Nullable private ShuffleInMemorySorter inMemSorter;
+  private LongArray sorterArray;
+
+  /**
+   * Creates a new SpillSorter with explicit dependencies.
+   *
+   * @param allocator Memory allocator for pages and arrays
+   * @param initialSize Initial size for the pointer array
+   * @param schema Schema of the records being sorted
+   * @param uaoSize Size of UnsafeAlignedOffset (4 or 8 bytes)
+   * @param preferDictionaryRatio Dictionary encoding preference ratio
+   * @param compressionCodec Compression codec for spill files
+   * @param compressionLevel Compression level
+   * @param checksumAlgorithm Checksum algorithm (e.g., "crc32", "adler32")
+   * @param partitionChecksums Array to store partition checksums (shared with parent); an empty
+   *     array disables checksum tracking
+   * @param writeMetrics Metrics reporter for shuffle writes
+   * @param taskContext Task context for metrics updates
+   * @param spills List to accumulate spill info (shared with parent)
+   * @param spillCallback Callback to trigger spill in parent sorter
+   */
+  public SpillSorter(
+      CometShuffleMemoryAllocatorTrait allocator,
+      int initialSize,
+      StructType schema,
+      int uaoSize,
+      double preferDictionaryRatio,
+      String compressionCodec,
+      int compressionLevel,
+      String checksumAlgorithm,
+      long[] partitionChecksums,
+      ShuffleWriteMetricsReporter writeMetrics,
+      TaskContext taskContext,
+      LinkedList<SpillInfo> spills,
+      SpillCallback spillCallback) {
+
+    this.initialSize = initialSize;
+    this.uaoSize = uaoSize;
+    this.preferDictionaryRatio = preferDictionaryRatio;
+    this.compressionCodec = compressionCodec;
+    this.compressionLevel = compressionLevel;
+    this.checksumAlgorithm = checksumAlgorithm;
+    this.partitionChecksums = partitionChecksums;
+    this.writeMetrics = writeMetrics;
+    this.taskContext = taskContext;
+    this.spills = spills;
+    this.spillCallback = spillCallback;
+
+    this.spillInfo = null;
+    this.allocator = allocator;
+
+    // Allocate the pointer array for the in-memory sorter ourselves. The sorter's internal array
+    // is private, but native code needs its address, so we allocate it here and hand it to the
+    // sorter via expandPointerArray. The sorter itself is created with the minimum non-zero
+    // initial size (1) because we don't want it to allocate the real array.
+    try {
+      this.inMemSorter = new ShuffleInMemorySorter(allocator, 1, true);
+    } catch (java.lang.IllegalAccessError e) {
+      throw new java.lang.RuntimeException(
+          "Error loading in-memory sorter check class path -- see "
+              + "https://github.com/apache/arrow-datafusion-comet?tab=readme-ov-file#enable-comet-shuffle",
+          e);
+    }
+    sorterArray = allocator.allocateArray(initialSize);
+    this.inMemSorter.expandPointerArray(sorterArray);
+
+    this.allocatedPages = new LinkedList<>();
+
+    this.nativeLib = new Native();
+    this.dataTypes = serializeSchema(schema);
+  }
+
+  /** Frees allocated memory pages of this writer */
+  @Override
+  public long freeMemory() {
+    // We need to synchronize here because we may get the memory usage by calling
+    // this method in the task thread.
+    synchronized (this) {
+      return super.freeMemory();
+    }
+  }
+
+  @Override
+  public long getMemoryUsage() {
+    // We need to synchronize here because we may free the memory pages in another thread,
+    // i.e. when spilling, but this method may be called in the task thread.
+    synchronized (this) {
+      long totalPageSize = super.getMemoryUsage();
+
+      if (freed) {
+        return totalPageSize;
+      } else {
+        return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageSize;
+      }
+    }
+  }
+
+  @Override
+  protected void spill(int required) throws IOException {
+    spillCallback.onSpillRequired();
+  }
+
+  /** Free the pointer array held by this sorter. */
+  public void freeArray() {
+    synchronized (this) {
+      inMemSorter.free();
+      freed = true;
+    }
+  }
+
+  /**
+   * Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the
+   * records.
+   */
+  public void reset() {
+    synchronized (this) {
+      // We allocate pointer array outside the sorter.
+      // So we can get array address which can be used by native code.
+      inMemSorter.reset();
+      sorterArray = allocator.allocateArray(initialSize);
+      inMemSorter.expandPointerArray(sorterArray);
+      freed = false;
+    }
+  }
+
+  void setSpillInfo(SpillInfo spillInfo) {
+    this.spillInfo = spillInfo;
+  }
+
+  public int numRecords() {
+    return this.inMemSorter.numRecords();
+  }
+
+  /**
+   * Sorts the buffered records by partition ID (in native code) and writes them, partition by
+   * partition, to the file in the current {@link SpillInfo}.
+   *
+   * @param isLastFile true when writing the final output file, whose bytes count as shuffle write
+   *     bytes; false for an intermediate spill, whose bytes count as disk bytes spilled
+   * @param tracingEnabled forwarded to the native sort and to the spill writer
+   */
+  public void writeSortedFileNative(boolean isLastFile, boolean tracingEnabled) throws IOException {
+    // This call performs the actual sort.
+    long arrayAddr = this.sorterArray.getBaseOffset();
+    int pos = inMemSorter.numRecords();
+    nativeLib.sortRowPartitionsNative(arrayAddr, pos, tracingEnabled);
+    ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords =
+        new ShuffleInMemorySorter.ShuffleSorterIterator(pos, this.sorterArray, 0);
+
+    // If there are no sorted records, there is no need to create an empty spill file.
+    if (!sortedRecords.hasNext()) {
+      return;
+    }
+
+    final ShuffleWriteMetricsReporter writeMetricsToUse;
+
+    if (isLastFile) {
+      // We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes.
+      writeMetricsToUse = writeMetrics;
+    } else {
+      // We're spilling, so bytes written should be counted towards spill rather than write.
+      // Create a dummy WriteMetrics object to absorb these metrics, since we don't want to count
+      // them towards shuffle bytes written.
+      writeMetricsToUse = new ShuffleWriteMetrics();
+    }
+
+    int currentPartition = -1;
+
+    final RowPartition rowPartition = new RowPartition(initialSize);
+
+    while (sortedRecords.hasNext()) {
+      sortedRecords.loadNext();
+      final int partition = sortedRecords.packedRecordPointer.getPartitionId();
+      assert (partition >= currentPartition);
+      if (partition != currentPartition) {
+        // Switch to the new partition, first writing out the rows buffered for the previous one.
+        if (currentPartition != -1) {
+          spillBufferedPartition(currentPartition, rowPartition, writeMetricsToUse, tracingEnabled);
+        }
+        currentPartition = partition;
+      }
+
+      final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer();
+      final long recordOffsetInPage = allocator.getOffsetInPage(recordPointer);
+      // Note that we need to skip over record key (partition id)
+      // Note that we already use off-heap memory for serialized rows, so recordPage is always
+      // null.
+      int recordSizeInBytes = UnsafeAlignedOffset.getSize(null, recordOffsetInPage) - 4;
+      long recordReadPosition = recordOffsetInPage + uaoSize + 4; // skip over record length too
+      rowPartition.addRow(recordReadPosition, recordSizeInBytes);
+    }
+
+    if (currentPartition != -1) {
+      // Write out the rows buffered for the last partition.
+      spillBufferedPartition(currentPartition, rowPartition, writeMetricsToUse, tracingEnabled);
+
+      synchronized (spills) {
+        spills.add(spillInfo);
+      }
+    }
+
+    if (!isLastFile) { // i.e. this is a spill file
+      // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records
+      // are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter
+      // relies on its `recordWritten()` method being called in order to trigger periodic updates
+      // to `shuffleBytesWritten`. If we were to remove the `recordWritten()` call and increment
+      // that counter at a higher-level, then the in-progress metrics for records written and bytes
+      // written would get out of sync.
+      //
+      // When writing the last file, we pass `writeMetrics` directly to the DiskBlockObjectWriter;
+      // in all other cases, we pass in a dummy write metrics to capture metrics, then copy those
+      // metrics to the true write metrics here. The reason for performing this copying is so that
+      // we can avoid reporting spilled bytes as shuffle write bytes.
+      //
+      // Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`.
+      // Consistent with ExternalSorter, we do not count this IO towards shuffle write time.
+      // SPARK-3577 tracks the spill time separately.
+
+      // This is guaranteed to be a ShuffleWriteMetrics based on the if check in the beginning
+      // of this method.
+      synchronized (writeMetrics) {
+        writeMetrics.incRecordsWritten(((ShuffleWriteMetrics) writeMetricsToUse).recordsWritten());
+        taskContext
+            .taskMetrics()
+            .incDiskBytesSpilled(((ShuffleWriteMetrics) writeMetricsToUse).bytesWritten());
+      }
+    }
+  }
+
+  /**
+   * Writes the rows currently buffered in {@code rowPartition} for one partition to the current
+   * spill file and records the number of bytes written in {@code spillInfo.partitionLengths}.
+   *
+   * <p>When checksums are enabled (non-empty {@code partitionChecksums}), the writer's checksum is
+   * set from the partition's stored value before writing and the resulting value is stored back
+   * afterwards. When checksums are disabled, {@code partitionChecksums} is never indexed.
+   */
+  private void spillBufferedPartition(
+      int partitionId,
+      RowPartition rowPartition,
+      ShuffleWriteMetricsReporter writeMetricsToUse,
+      boolean tracingEnabled)
+      throws IOException {
+    final boolean checksumEnabled = partitionChecksums.length > 0;
+    if (checksumEnabled) {
+      setChecksum(partitionChecksums[partitionId]);
+      setChecksumAlgo(checksumAlgorithm);
+    }
+
+    long written =
+        doSpilling(
+            dataTypes,
+            spillInfo.file,
+            rowPartition,
+            writeMetricsToUse,
+            preferDictionaryRatio,
+            compressionCodec,
+            compressionLevel,
+            tracingEnabled);
+    spillInfo.partitionLengths[partitionId] = written;
+
+    if (checksumEnabled) {
+      partitionChecksums[partitionId] = getChecksum();
+    }
+  }
+
+  public boolean hasSpaceForAnotherRecord() {
+    return inMemSorter.hasSpaceForAnotherRecord();
+  }
+
+  public void expandPointerArray(LongArray newArray) {
+    inMemSorter.expandPointerArray(newArray);
+    this.sorterArray = newArray;
+  }
+
+  public void insertRecord(Object recordBase, long recordOffset, int length, int partitionId) {
+    final Object base = currentPage.getBaseObject();
+    final long recordAddress = allocator.encodePageNumberAndOffset(currentPage, pageCursor);
+    UnsafeAlignedOffset.putSize(base, pageCursor, length);
+    pageCursor += uaoSize;
+    Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
+    pageCursor += length;
+    inMemSorter.insertRecord(recordAddress, partitionId);
+  }
+}
diff --git
a/spark/src/test/scala/org/apache/spark/shuffle/sort/SpillSorterSuite.scala
b/spark/src/test/scala/org/apache/spark/shuffle/sort/SpillSorterSuite.scala
new file mode 100644
index 000000000..dfbe38b64
--- /dev/null
+++ b/spark/src/test/scala/org/apache/spark/shuffle/sort/SpillSorterSuite.scala
@@ -0,0 +1,270 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.shuffle.sort
+
+import java.util.concurrent.ConcurrentLinkedQueue
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.spark.{SparkConf, TaskContext}
+import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
+import org.apache.spark.shuffle.comet.CometShuffleMemoryAllocator
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.UnsafeAlignedOffset
+
+/**
+ * Unit tests for [[SpillSorter]].
+ *
+ * These tests verify SpillSorter behavior using Spark's test memory management infrastructure,
+ * without needing a full SparkContext.
+ */
+class SpillSorterSuite extends AnyFunSuite with BeforeAndAfterEach {
+
+  private val INITIAL_SIZE = 1024
+  private val UAO_SIZE = UnsafeAlignedOffset.getUaoSize
+  private val PAGE_SIZE = 4 * 1024 * 1024 // 4MB
+
+  private var conf: SparkConf = _
+  private var memoryManager: TestMemoryManager = _
+  private var taskMemoryManager: TaskMemoryManager = _
+
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+    conf = new SparkConf()
+      .set("spark.memory.offHeap.enabled", "false")
+    memoryManager = new TestMemoryManager(conf)
+    memoryManager.limit(100 * 1024 * 1024) // 100MB
+    taskMemoryManager = new TaskMemoryManager(memoryManager, 0)
+  }
+
+  override def afterEach(): Unit = {
+    if (taskMemoryManager != null) {
+      taskMemoryManager.cleanUpAllAllocatedMemory()
+      taskMemoryManager = null
+    }
+    memoryManager = null
+    super.afterEach()
+  }
+
+  private def createTestSchema(): StructType = {
+    new StructType().add("id", IntegerType)
+  }
+
+  private def createSpillSorter(
+      spillCallback: SpillSorter.SpillCallback = () => {},
+      spills: java.util.LinkedList[org.apache.spark.sql.comet.execution.shuffle.SpillInfo] =
+        new java.util.LinkedList[org.apache.spark.sql.comet.execution.shuffle.SpillInfo](),
+      partitionChecksums: Array[Long] = new Array[Long](10)): SpillSorter = {
+    val allocator = CometShuffleMemoryAllocator.getInstance(conf, taskMemoryManager, PAGE_SIZE)
+    val schema = createTestSchema()
+    val writeMetrics = new ShuffleWriteMetrics()
+    val taskContext = TaskContext.empty()
+
+    new SpillSorter(
+      allocator,
+      INITIAL_SIZE,
+      schema,
+      UAO_SIZE,
+      0.5, // preferDictionaryRatio
+      "zstd", // compressionCodec
+      1, // compressionLevel
+      "adler32", // checksumAlgorithm
+      partitionChecksums,
+      writeMetrics,
+      taskContext,
+      spills,
+      spillCallback)
+  }
+
+  test("initial state") {
+    val sorter = createSpillSorter()
+    try {
+      assert(sorter.numRecords() === 0)
+      assert(sorter.hasSpaceForAnotherRecord())
+      assert(sorter.getMemoryUsage() > 0)
+    } finally {
+      sorter.freeMemory()
+      sorter.freeArray()
+    }
+  }
+
+  test("insert single record") {
+    val sorter = createSpillSorter()
+    try {
+      val recordData = Array[Byte](1, 2, 3, 4)
+      val partitionId = 0
+
+      sorter.initialCurrentPage(recordData.length + UAO_SIZE)
+      sorter.insertRecord(recordData, Platform.BYTE_ARRAY_OFFSET, recordData.length, partitionId)
+
+      assert(sorter.numRecords() === 1)
+    } finally {
+      sorter.freeMemory()
+      sorter.freeArray()
+    }
+  }
+
+  test("insert multiple records") {
+    val sorter = createSpillSorter()
+    try {
+      val recordData = Array[Byte](1, 2, 3, 4)
+      val numRecords = 100
+
+      sorter.initialCurrentPage(numRecords * (recordData.length + UAO_SIZE))
+
+      for (i <- 0 until numRecords) {
+        val partitionId = i % 10
+        sorter.insertRecord(
+          recordData,
+          Platform.BYTE_ARRAY_OFFSET,
+          recordData.length,
+          partitionId)
+      }
+
+      assert(sorter.numRecords() === numRecords)
+    } finally {
+      sorter.freeMemory()
+      sorter.freeArray()
+    }
+  }
+
+  test("reset after free") {
+    val sorter = createSpillSorter()
+    try {
+      val recordData = Array[Byte](1, 2, 3, 4)
+      sorter.initialCurrentPage(recordData.length + UAO_SIZE)
+      sorter.insertRecord(recordData, Platform.BYTE_ARRAY_OFFSET, recordData.length, 0)
+
+      assert(sorter.numRecords() === 1)
+
+      sorter.freeMemory()
+      sorter.reset()
+
+      assert(sorter.numRecords() === 0)
+      assert(sorter.hasSpaceForAnotherRecord())
+    } finally {
+      sorter.freeMemory()
+      sorter.freeArray()
+    }
+  }
+
+  test("free memory returns correct value") {
+    val sorter = createSpillSorter()
+    try {
+      sorter.initialCurrentPage(1024)
+      val memoryBefore = sorter.getMemoryUsage()
+      assert(memoryBefore > 0)
+
+      val freed = sorter.freeMemory()
+      assert(freed > 0)
+    } finally {
+      sorter.freeArray()
+    }
+  }
+
+  test("spill callback not triggered during normal operations") {
+    val spillCount = new AtomicInteger(0)
+    val callback: SpillSorter.SpillCallback = () => spillCount.incrementAndGet()
+
+    val sorter = createSpillSorter(spillCallback = callback)
+    try {
+      sorter.initialCurrentPage(1024)
+      val recordData = Array[Byte](1, 2, 3, 4)
+      sorter.insertRecord(recordData, Platform.BYTE_ARRAY_OFFSET, recordData.length, 0)
+
+      assert(spillCount.get() === 0, "Spill callback should not be triggered during normal ops")
+    } finally {
+      sorter.freeMemory()
+      sorter.freeArray()
+    }
+  }
+
+  test("getMemoryUsage is thread-safe") {
+    val sorter = createSpillSorter()
+    try {
+      sorter.initialCurrentPage(1024)
+
+      // An exception thrown inside a Thread is not rethrown by join(), so failures are collected
+      // here and asserted on after all threads finish; otherwise this test could never fail.
+      val failures = new ConcurrentLinkedQueue[Throwable]()
+      val threads = (0 until 10).map { _ =>
+        new Thread(() => {
+          try {
+            for (_ <- 0 until 100) {
+              assert(sorter.getMemoryUsage() > 0)
+            }
+          } catch {
+            case t: Throwable => failures.add(t)
+          }
+        })
+      }
+
+      threads.foreach(_.start())
+      threads.foreach(_.join())
+      assert(failures.isEmpty, s"getMemoryUsage failed in a worker thread: ${failures.peek()}")
+    } finally {
+      sorter.freeMemory()
+      sorter.freeArray()
+    }
+  }
+
+  test("expand pointer array") {
+    val sorter = createSpillSorter()
+    try {
+      val initialMemory = sorter.getMemoryUsage()
+      val allocator = CometShuffleMemoryAllocator.getInstance(conf, taskMemoryManager, PAGE_SIZE)
+      val newArray = allocator.allocateArray(INITIAL_SIZE * 2)
+      sorter.expandPointerArray(newArray)
+
+      assert(sorter.getMemoryUsage() >= initialMemory)
+    } finally {
+      sorter.freeMemory()
+      sorter.freeArray()
+    }
+  }
+
+  test("records distributed across partitions") {
+    val sorter = createSpillSorter()
+    try {
+      val recordData = Array[Byte](1, 2, 3, 4)
+      val numPartitions = 5
+      val recordsPerPartition = 20
+
+      sorter.initialCurrentPage(
+        numPartitions * recordsPerPartition * (recordData.length + UAO_SIZE))
+
+      for (p <- 0 until numPartitions) {
+        for (_ <- 0 until recordsPerPartition) {
+          sorter.insertRecord(recordData, Platform.BYTE_ARRAY_OFFSET, recordData.length, p)
+        }
+      }
+
+      assert(sorter.numRecords() === numPartitions * recordsPerPartition)
+    } finally {
+      sorter.freeMemory()
+      sorter.freeArray()
+    }
+  }
+
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]