This is an automated email from the ASF dual-hosted git repository.
zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 1f809ed26 [#2718] feat(spark): Eliminate copy in WriterBuffer when
compression off for Gluten (#2720)
1f809ed26 is described below
commit 1f809ed263cb96dffded0988848977689c058d04
Author: Junfan Zhang <[email protected]>
AuthorDate: Fri Jan 30 16:46:43 2026 +0800
[#2718] feat(spark): Eliminate copy in WriterBuffer when compression off
for Gluten (#2720)
### What changes were proposed in this pull request?
- This change introduces `CompositeByteBuf` in `WriterBuffer` to provide a
zero-copy data view, which is especially beneficial in Gluten scenarios where
compression is disabled and handled by the Gluten side.
- In addition, the CRC32 generator is enhanced to operate directly on
ByteBuf.
To better achieve zero-copy, the codec should accept `ByteBuf` directly;
this can be done in follow-up PRs.
### Why are the changes needed?
For #2718.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Unit tests
---
.../spark/shuffle/writer/WriteBufferManager.java | 38 ++++++++------
.../apache/spark/shuffle/writer/WriterBuffer.java | 39 +++++++++++----
.../apache/uniffle/common/ShuffleBlockInfo.java | 54 ++++++++++++++++++++
.../apache/uniffle/common/util/ChecksumUtils.java | 23 +++++++++
.../uniffle/common/util/ChecksumUtilsTest.java | 58 ++++++++++++++++++++++
5 files changed, 186 insertions(+), 26 deletions(-)
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index 0321311a6..b50cc5d56 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -36,6 +36,8 @@ import scala.reflect.ManifestFactory$;
import com.clearspring.analytics.util.Lists;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Maps;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.MemoryMode;
@@ -423,10 +425,8 @@ public class WriteBufferManager extends MemoryConsumer {
protected ShuffleBlockInfo createDeferredCompressedBlock(
int partitionId, WriterBuffer writerBuffer) {
- byte[] data = writerBuffer.getData();
- final int uncompressLength = data.length;
+ final int uncompressLength = writerBuffer.getDataLength();
final int memoryUsed = writerBuffer.getMemoryUsed();
- final long records = writerBuffer.getRecordCount();
this.blockCounter.incrementAndGet();
this.uncompressedDataLen += uncompressLength;
@@ -435,12 +435,15 @@ public class WriteBufferManager extends MemoryConsumer {
final long blockId =
blockIdLayout.getBlockId(getNextSeqNo(partitionId), partitionId,
taskAttemptId);
+ // todo: support ByteBuf compress directly to avoid copying
+ final byte[] rawData = writerBuffer.getData();
+
Function<DeferredCompressedBlock, DeferredCompressedBlock> rebuildFunction
=
block -> {
- byte[] compressed = data;
+ byte[] compressed = rawData;
if (codec.isPresent()) {
long start = System.currentTimeMillis();
- compressed = codec.get().compress(data);
+ compressed = codec.get().compress(rawData);
this.compressTime += System.currentTimeMillis() - start;
}
this.compressedDataLen += compressed.length;
@@ -451,9 +454,9 @@ public class WriteBufferManager extends MemoryConsumer {
return block;
};
- int estimatedCompressedSize = data.length;
+ int estimatedCompressedSize = uncompressLength;
if (codec.isPresent()) {
- estimatedCompressedSize = codec.get().maxCompressedLength(data.length);
+ estimatedCompressedSize =
codec.get().maxCompressedLength(uncompressLength);
}
return new DeferredCompressedBlock(
@@ -467,7 +470,7 @@ public class WriteBufferManager extends MemoryConsumer {
partitionAssignmentRetrieveFunc,
rebuildFunction,
estimatedCompressedSize,
- records);
+ writerBuffer.getRecordCount());
}
// transform records to shuffleBlock
@@ -476,28 +479,31 @@ public class WriteBufferManager extends MemoryConsumer {
return createDeferredCompressedBlock(partitionId, wb);
}
- byte[] data = wb.getData();
- final int uncompressLength = data.length;
- byte[] compressed = data;
+ final int uncompressLength = wb.getDataLength();
+ final ByteBuf data = wb.getDataAsByteBuf();
+ ByteBuf compressed = data;
if (codec.isPresent()) {
long start = System.currentTimeMillis();
- compressed = codec.get().compress(data);
+ // todo: support ByteBuf compress directly to avoid copying
+ byte[] compressedByteArr = codec.get().compress(wb.getData());
+ compressed = Unpooled.wrappedBuffer(compressedByteArr);
compressTime += System.currentTimeMillis() - start;
}
final long crc32 = ChecksumUtils.getCrc32(compressed);
final long blockId =
blockIdLayout.getBlockId(getNextSeqNo(partitionId), partitionId,
taskAttemptId);
blockCounter.incrementAndGet();
- uncompressedDataLen += data.length;
- compressedDataLen += compressed.length;
- shuffleWriteMetrics.incBytesWritten(compressed.length);
+ final int compressedLen = compressed.readableBytes();
+ uncompressedDataLen += uncompressLength;
+ compressedDataLen += compressedLen;
+ shuffleWriteMetrics.incBytesWritten(compressedLen);
// add memory to indicate bytes which will be sent to shuffle server
inSendListBytes.addAndGet(wb.getMemoryUsed());
return new ShuffleBlockInfo(
shuffleId,
partitionId,
blockId,
- compressed.length,
+ compressedLen,
crc32,
compressed,
partitionAssignmentRetrieveFunc.apply(partitionId),
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriterBuffer.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriterBuffer.java
index ac6ac9e27..41fc4d037 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriterBuffer.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriterBuffer.java
@@ -19,7 +19,11 @@ package org.apache.spark.shuffle.writer;
import java.util.List;
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.CompositeByteBuf;
+import io.netty.buffer.Unpooled;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -74,18 +78,33 @@ public class WriterBuffer {
return buffer == null || nextOffset + length > bufferSize;
}
+ @VisibleForTesting
public byte[] getData() {
- byte[] data = new byte[dataLength];
- int offset = 0;
- long start = System.currentTimeMillis();
- for (WrappedBuffer wrappedBuffer : buffers) {
- System.arraycopy(wrappedBuffer.getBuffer(), 0, data, offset,
wrappedBuffer.getSize());
- offset += wrappedBuffer.getSize();
+ ByteBuf buf = getDataAsByteBuf();
+ byte[] result = new byte[buf.readableBytes()];
+ buf.getBytes(0, result);
+ return result;
+ }
+
+ public ByteBuf getDataAsByteBuf() {
+ if (buffers.isEmpty()) {
+ if (buffer == null || nextOffset <= 0) {
+ return Unpooled.EMPTY_BUFFER;
+ }
+ return Unpooled.wrappedBuffer(buffer, 0, nextOffset);
+ }
+
+ CompositeByteBuf composite = Unpooled.compositeBuffer(buffers.size() + 1);
+ for (WrappedBuffer stagingBuffer : buffers) {
+ if (stagingBuffer.getSize() > 0) {
+ composite.addComponent(
+ true, Unpooled.wrappedBuffer(stagingBuffer.getBuffer(), 0,
stagingBuffer.getSize()));
+ }
+ }
+ if (buffer != null && nextOffset > 0) {
+ composite.addComponent(true, Unpooled.wrappedBuffer(buffer, 0,
nextOffset));
}
- // nextOffset is the length of current buffer used
- System.arraycopy(buffer, 0, data, offset, nextOffset);
- copyTime += System.currentTimeMillis() - start;
- return data;
+ return composite;
}
public int getDataLength() {
diff --git
a/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
b/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
index 1169b2931..8cfa196b9 100644
--- a/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
+++ b/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
@@ -71,6 +71,34 @@ public class ShuffleBlockInfo {
this.recordNumber = records;
}
+ public ShuffleBlockInfo(
+ int shuffleId,
+ int partitionId,
+ long blockId,
+ int length,
+ long crc,
+ ByteBuf data,
+ List<ShuffleServerInfo> shuffleServerInfos,
+ int uncompressLength,
+ long freeMemory,
+ long taskAttemptId,
+ Function<Integer, List<ShuffleServerInfo>>
partitionAssignmentRetrieveFunc,
+ long records) {
+ this(
+ shuffleId,
+ partitionId,
+ blockId,
+ length,
+ crc,
+ data,
+ shuffleServerInfos,
+ uncompressLength,
+ freeMemory,
+ taskAttemptId);
+ this.partitionAssignmentRetrieveFunc = partitionAssignmentRetrieveFunc;
+ this.recordNumber = records;
+ }
+
public ShuffleBlockInfo(
int shuffleId,
int partitionId,
@@ -97,6 +125,32 @@ public class ShuffleBlockInfo {
this.partitionAssignmentRetrieveFunc = partitionAssignmentRetrieveFunc;
}
+ public ShuffleBlockInfo(
+ int shuffleId,
+ int partitionId,
+ long blockId,
+ int length,
+ long crc,
+ ByteBuf data,
+ List<ShuffleServerInfo> shuffleServerInfos,
+ int uncompressLength,
+ long freeMemory,
+ long taskAttemptId,
+ Function<Integer, List<ShuffleServerInfo>>
partitionAssignmentRetrieveFunc) {
+ this(
+ shuffleId,
+ partitionId,
+ blockId,
+ length,
+ crc,
+ data,
+ shuffleServerInfos,
+ uncompressLength,
+ freeMemory,
+ taskAttemptId);
+ this.partitionAssignmentRetrieveFunc = partitionAssignmentRetrieveFunc;
+ }
+
protected ShuffleBlockInfo(
int shuffleId,
int partitionId,
diff --git
a/common/src/main/java/org/apache/uniffle/common/util/ChecksumUtils.java
b/common/src/main/java/org/apache/uniffle/common/util/ChecksumUtils.java
index 32ecf8676..7b89baaee 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/ChecksumUtils.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/ChecksumUtils.java
@@ -20,6 +20,8 @@ package org.apache.uniffle.common.util;
import java.nio.ByteBuffer;
import java.util.zip.CRC32;
+import io.netty.buffer.ByteBuf;
+
public class ChecksumUtils {
private static final int LENGTH_PER_CRC = 4 * 1024;
@@ -56,4 +58,25 @@ public class ChecksumUtils {
}
return crc32.getValue();
}
+
+ public static long getCrc32(ByteBuf byteBuf) {
+ final int offset = byteBuf.readerIndex();
+ final int length = byteBuf.readableBytes();
+ if (length == 0) {
+ return 0L;
+ }
+
+ // Avoid coalescing/copy for composite buffers by iterating over
nioBuffers.
+ if (byteBuf.nioBufferCount() == 1) {
+ return getCrc32(byteBuf.nioBuffer(offset, length));
+ }
+
+ CRC32 crc32 = new CRC32();
+ for (ByteBuffer bb : byteBuf.nioBuffers(offset, length)) {
+ // `nioBuffers` returns fresh ByteBuffer views; CRC32.update(ByteBuffer)
only advances
+ // the ByteBuffer position and won't affect the underlying ByteBuf
indices.
+ crc32.update(bb);
+ }
+ return crc32.getValue();
+ }
}
diff --git
a/common/src/test/java/org/apache/uniffle/common/util/ChecksumUtilsTest.java
b/common/src/test/java/org/apache/uniffle/common/util/ChecksumUtilsTest.java
index 100be0f77..82285c60f 100644
--- a/common/src/test/java/org/apache/uniffle/common/util/ChecksumUtilsTest.java
+++ b/common/src/test/java/org/apache/uniffle/common/util/ChecksumUtilsTest.java
@@ -26,10 +26,14 @@ import java.nio.file.Paths;
import java.util.Random;
import java.util.zip.CRC32;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.CompositeByteBuf;
+import io.netty.buffer.Unpooled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
public class ChecksumUtilsTest {
@@ -112,4 +116,58 @@ public class ChecksumUtilsTest {
directOffsetBuffer.put(data);
assertEquals(expectCrc, ChecksumUtils.getCrc32(directOffsetBuffer, offset,
length));
}
+
+ @Test
+ public void crc32ByteBufEmptyReadableBytesShouldReturnZero() {
+ ByteBuf byteBuf = Unpooled.buffer(16);
+ assertEquals(0, byteBuf.readableBytes());
+ assertEquals(0L, ChecksumUtils.getCrc32(byteBuf));
+ }
+
+ @Test
+ public void crc32ByteBufShouldRespectReaderIndexAndNotChangeIt() {
+ byte[] data = new byte[1024];
+ new Random().nextBytes(data);
+ ByteBuf byteBuf = Unpooled.wrappedBuffer(data);
+
+ int readerIndex = 17;
+ byteBuf.readerIndex(readerIndex);
+
+ CRC32 crc32 = new CRC32();
+ crc32.update(data, readerIndex, data.length - readerIndex);
+ long expected = crc32.getValue();
+
+ assertEquals(expected, ChecksumUtils.getCrc32(byteBuf));
+ assertEquals(readerIndex, byteBuf.readerIndex());
+ }
+
+ @Test
+ public void crc32CompositeByteBufShouldIterateOverNioBuffers() {
+ byte[] part1 = new byte[128];
+ byte[] part2 = new byte[256];
+ Random random = new Random();
+ random.nextBytes(part1);
+ random.nextBytes(part2);
+
+ CompositeByteBuf composite = Unpooled.compositeBuffer();
+ composite.addComponent(true, Unpooled.wrappedBuffer(part1));
+ composite.addComponent(true, Unpooled.wrappedBuffer(part2));
+
+ // Ensure this test hits the composite path (nioBufferCount > 1).
+ // Note: CompositeByteBuf.nioBufferCount() depends on
readerIndex/readableBytes, so check it
+ // here.
+ assertTrue(composite.nioBufferCount() > 1);
+
+ int skip = 13; // cross-component offsets are fine; expected CRC is
computed on readable bytes.
+ composite.skipBytes(skip);
+
+ CRC32 crc32 = new CRC32();
+ crc32.update(part1, skip, part1.length - skip);
+ crc32.update(part2, 0, part2.length);
+ long expected = crc32.getValue();
+
+ int readerIndex = composite.readerIndex();
+ assertEquals(expected, ChecksumUtils.getCrc32(composite));
+ assertEquals(readerIndex, composite.readerIndex());
+ }
}