This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 1f809ed26 [#2718] feat(spark): Eliminate copy in WriterBuffer when 
compression off for Gluten (#2720)
1f809ed26 is described below

commit 1f809ed263cb96dffded0988848977689c058d04
Author: Junfan Zhang <[email protected]>
AuthorDate: Fri Jan 30 16:46:43 2026 +0800

    [#2718] feat(spark): Eliminate copy in WriterBuffer when compression off 
for Gluten (#2720)
    
    ### What changes were proposed in this pull request?
    
    - This change introduces `CompositeByteBuf` in `WriterBuffer` to provide a 
zero-copy data view, which is especially beneficial in Gluten scenarios where 
compression is disabled and handled by the Gluten side.
    - In addition, the CRC32 generator is enhanced to operate directly on 
ByteBuf.
    
    To better achieve zero-copy, the codec should ideally accept `ByteBuf` 
directly; this can be addressed in follow-up PRs.
    
    ### Why are the changes needed?
    
    for #2718
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Unit tests
---
 .../spark/shuffle/writer/WriteBufferManager.java   | 38 ++++++++------
 .../apache/spark/shuffle/writer/WriterBuffer.java  | 39 +++++++++++----
 .../apache/uniffle/common/ShuffleBlockInfo.java    | 54 ++++++++++++++++++++
 .../apache/uniffle/common/util/ChecksumUtils.java  | 23 +++++++++
 .../uniffle/common/util/ChecksumUtilsTest.java     | 58 ++++++++++++++++++++++
 5 files changed, 186 insertions(+), 26 deletions(-)

diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index 0321311a6..b50cc5d56 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -36,6 +36,8 @@ import scala.reflect.ManifestFactory$;
 import com.clearspring.analytics.util.Lists;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Maps;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.memory.MemoryConsumer;
 import org.apache.spark.memory.MemoryMode;
@@ -423,10 +425,8 @@ public class WriteBufferManager extends MemoryConsumer {
 
   protected ShuffleBlockInfo createDeferredCompressedBlock(
       int partitionId, WriterBuffer writerBuffer) {
-    byte[] data = writerBuffer.getData();
-    final int uncompressLength = data.length;
+    final int uncompressLength = writerBuffer.getDataLength();
     final int memoryUsed = writerBuffer.getMemoryUsed();
-    final long records = writerBuffer.getRecordCount();
 
     this.blockCounter.incrementAndGet();
     this.uncompressedDataLen += uncompressLength;
@@ -435,12 +435,15 @@ public class WriteBufferManager extends MemoryConsumer {
     final long blockId =
         blockIdLayout.getBlockId(getNextSeqNo(partitionId), partitionId, 
taskAttemptId);
 
+    // todo: support ByteBuf compress directly to avoid copying
+    final byte[] rawData = writerBuffer.getData();
+
     Function<DeferredCompressedBlock, DeferredCompressedBlock> rebuildFunction 
=
         block -> {
-          byte[] compressed = data;
+          byte[] compressed = rawData;
           if (codec.isPresent()) {
             long start = System.currentTimeMillis();
-            compressed = codec.get().compress(data);
+            compressed = codec.get().compress(rawData);
             this.compressTime += System.currentTimeMillis() - start;
           }
           this.compressedDataLen += compressed.length;
@@ -451,9 +454,9 @@ public class WriteBufferManager extends MemoryConsumer {
           return block;
         };
 
-    int estimatedCompressedSize = data.length;
+    int estimatedCompressedSize = uncompressLength;
     if (codec.isPresent()) {
-      estimatedCompressedSize = codec.get().maxCompressedLength(data.length);
+      estimatedCompressedSize = 
codec.get().maxCompressedLength(uncompressLength);
     }
 
     return new DeferredCompressedBlock(
@@ -467,7 +470,7 @@ public class WriteBufferManager extends MemoryConsumer {
         partitionAssignmentRetrieveFunc,
         rebuildFunction,
         estimatedCompressedSize,
-        records);
+        writerBuffer.getRecordCount());
   }
 
   // transform records to shuffleBlock
@@ -476,28 +479,31 @@ public class WriteBufferManager extends MemoryConsumer {
       return createDeferredCompressedBlock(partitionId, wb);
     }
 
-    byte[] data = wb.getData();
-    final int uncompressLength = data.length;
-    byte[] compressed = data;
+    final int uncompressLength = wb.getDataLength();
+    final ByteBuf data = wb.getDataAsByteBuf();
+    ByteBuf compressed = data;
     if (codec.isPresent()) {
       long start = System.currentTimeMillis();
-      compressed = codec.get().compress(data);
+      // todo: support ByteBuf compress directly to avoid copying
+      byte[] compressedByteArr = codec.get().compress(wb.getData());
+      compressed = Unpooled.wrappedBuffer(compressedByteArr);
       compressTime += System.currentTimeMillis() - start;
     }
     final long crc32 = ChecksumUtils.getCrc32(compressed);
     final long blockId =
         blockIdLayout.getBlockId(getNextSeqNo(partitionId), partitionId, 
taskAttemptId);
     blockCounter.incrementAndGet();
-    uncompressedDataLen += data.length;
-    compressedDataLen += compressed.length;
-    shuffleWriteMetrics.incBytesWritten(compressed.length);
+    final int compressedLen = compressed.readableBytes();
+    uncompressedDataLen += uncompressLength;
+    compressedDataLen += compressedLen;
+    shuffleWriteMetrics.incBytesWritten(compressedLen);
     // add memory to indicate bytes which will be sent to shuffle server
     inSendListBytes.addAndGet(wb.getMemoryUsed());
     return new ShuffleBlockInfo(
         shuffleId,
         partitionId,
         blockId,
-        compressed.length,
+        compressedLen,
         crc32,
         compressed,
         partitionAssignmentRetrieveFunc.apply(partitionId),
diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriterBuffer.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriterBuffer.java
index ac6ac9e27..41fc4d037 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriterBuffer.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriterBuffer.java
@@ -19,7 +19,11 @@ package org.apache.spark.shuffle.writer;
 
 import java.util.List;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Lists;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.CompositeByteBuf;
+import io.netty.buffer.Unpooled;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -74,18 +78,33 @@ public class WriterBuffer {
     return buffer == null || nextOffset + length > bufferSize;
   }
 
+  @VisibleForTesting
   public byte[] getData() {
-    byte[] data = new byte[dataLength];
-    int offset = 0;
-    long start = System.currentTimeMillis();
-    for (WrappedBuffer wrappedBuffer : buffers) {
-      System.arraycopy(wrappedBuffer.getBuffer(), 0, data, offset, 
wrappedBuffer.getSize());
-      offset += wrappedBuffer.getSize();
+    ByteBuf buf = getDataAsByteBuf();
+    byte[] result = new byte[buf.readableBytes()];
+    buf.getBytes(0, result);
+    return result;
+  }
+
+  public ByteBuf getDataAsByteBuf() {
+    if (buffers.isEmpty()) {
+      if (buffer == null || nextOffset <= 0) {
+        return Unpooled.EMPTY_BUFFER;
+      }
+      return Unpooled.wrappedBuffer(buffer, 0, nextOffset);
+    }
+
+    CompositeByteBuf composite = Unpooled.compositeBuffer(buffers.size() + 1);
+    for (WrappedBuffer stagingBuffer : buffers) {
+      if (stagingBuffer.getSize() > 0) {
+        composite.addComponent(
+            true, Unpooled.wrappedBuffer(stagingBuffer.getBuffer(), 0, 
stagingBuffer.getSize()));
+      }
+    }
+    if (buffer != null && nextOffset > 0) {
+      composite.addComponent(true, Unpooled.wrappedBuffer(buffer, 0, 
nextOffset));
     }
-    // nextOffset is the length of current buffer used
-    System.arraycopy(buffer, 0, data, offset, nextOffset);
-    copyTime += System.currentTimeMillis() - start;
-    return data;
+    return composite;
   }
 
   public int getDataLength() {
diff --git 
a/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java 
b/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
index 1169b2931..8cfa196b9 100644
--- a/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
+++ b/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
@@ -71,6 +71,34 @@ public class ShuffleBlockInfo {
     this.recordNumber = records;
   }
 
+  public ShuffleBlockInfo(
+      int shuffleId,
+      int partitionId,
+      long blockId,
+      int length,
+      long crc,
+      ByteBuf data,
+      List<ShuffleServerInfo> shuffleServerInfos,
+      int uncompressLength,
+      long freeMemory,
+      long taskAttemptId,
+      Function<Integer, List<ShuffleServerInfo>> 
partitionAssignmentRetrieveFunc,
+      long records) {
+    this(
+        shuffleId,
+        partitionId,
+        blockId,
+        length,
+        crc,
+        data,
+        shuffleServerInfos,
+        uncompressLength,
+        freeMemory,
+        taskAttemptId);
+    this.partitionAssignmentRetrieveFunc = partitionAssignmentRetrieveFunc;
+    this.recordNumber = records;
+  }
+
   public ShuffleBlockInfo(
       int shuffleId,
       int partitionId,
@@ -97,6 +125,32 @@ public class ShuffleBlockInfo {
     this.partitionAssignmentRetrieveFunc = partitionAssignmentRetrieveFunc;
   }
 
+  public ShuffleBlockInfo(
+      int shuffleId,
+      int partitionId,
+      long blockId,
+      int length,
+      long crc,
+      ByteBuf data,
+      List<ShuffleServerInfo> shuffleServerInfos,
+      int uncompressLength,
+      long freeMemory,
+      long taskAttemptId,
+      Function<Integer, List<ShuffleServerInfo>> 
partitionAssignmentRetrieveFunc) {
+    this(
+        shuffleId,
+        partitionId,
+        blockId,
+        length,
+        crc,
+        data,
+        shuffleServerInfos,
+        uncompressLength,
+        freeMemory,
+        taskAttemptId);
+    this.partitionAssignmentRetrieveFunc = partitionAssignmentRetrieveFunc;
+  }
+
   protected ShuffleBlockInfo(
       int shuffleId,
       int partitionId,
diff --git 
a/common/src/main/java/org/apache/uniffle/common/util/ChecksumUtils.java 
b/common/src/main/java/org/apache/uniffle/common/util/ChecksumUtils.java
index 32ecf8676..7b89baaee 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/ChecksumUtils.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/ChecksumUtils.java
@@ -20,6 +20,8 @@ package org.apache.uniffle.common.util;
 import java.nio.ByteBuffer;
 import java.util.zip.CRC32;
 
+import io.netty.buffer.ByteBuf;
+
 public class ChecksumUtils {
 
   private static final int LENGTH_PER_CRC = 4 * 1024;
@@ -56,4 +58,25 @@ public class ChecksumUtils {
     }
     return crc32.getValue();
   }
+
+  public static long getCrc32(ByteBuf byteBuf) {
+    final int offset = byteBuf.readerIndex();
+    final int length = byteBuf.readableBytes();
+    if (length == 0) {
+      return 0L;
+    }
+
+    // Avoid coalescing/copy for composite buffers by iterating over 
nioBuffers.
+    if (byteBuf.nioBufferCount() == 1) {
+      return getCrc32(byteBuf.nioBuffer(offset, length));
+    }
+
+    CRC32 crc32 = new CRC32();
+    for (ByteBuffer bb : byteBuf.nioBuffers(offset, length)) {
+      // `nioBuffers` returns fresh ByteBuffer views; CRC32.update(ByteBuffer) 
only advances
+      // the ByteBuffer position and won't affect the underlying ByteBuf 
indices.
+      crc32.update(bb);
+    }
+    return crc32.getValue();
+  }
 }
diff --git 
a/common/src/test/java/org/apache/uniffle/common/util/ChecksumUtilsTest.java 
b/common/src/test/java/org/apache/uniffle/common/util/ChecksumUtilsTest.java
index 100be0f77..82285c60f 100644
--- a/common/src/test/java/org/apache/uniffle/common/util/ChecksumUtilsTest.java
+++ b/common/src/test/java/org/apache/uniffle/common/util/ChecksumUtilsTest.java
@@ -26,10 +26,14 @@ import java.nio.file.Paths;
 import java.util.Random;
 import java.util.zip.CRC32;
 
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.CompositeByteBuf;
+import io.netty.buffer.Unpooled;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.io.TempDir;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
 
 public class ChecksumUtilsTest {
 
@@ -112,4 +116,58 @@ public class ChecksumUtilsTest {
     directOffsetBuffer.put(data);
     assertEquals(expectCrc, ChecksumUtils.getCrc32(directOffsetBuffer, offset, 
length));
   }
+
+  @Test
+  public void crc32ByteBufEmptyReadableBytesShouldReturnZero() {
+    ByteBuf byteBuf = Unpooled.buffer(16);
+    assertEquals(0, byteBuf.readableBytes());
+    assertEquals(0L, ChecksumUtils.getCrc32(byteBuf));
+  }
+
+  @Test
+  public void crc32ByteBufShouldRespectReaderIndexAndNotChangeIt() {
+    byte[] data = new byte[1024];
+    new Random().nextBytes(data);
+    ByteBuf byteBuf = Unpooled.wrappedBuffer(data);
+
+    int readerIndex = 17;
+    byteBuf.readerIndex(readerIndex);
+
+    CRC32 crc32 = new CRC32();
+    crc32.update(data, readerIndex, data.length - readerIndex);
+    long expected = crc32.getValue();
+
+    assertEquals(expected, ChecksumUtils.getCrc32(byteBuf));
+    assertEquals(readerIndex, byteBuf.readerIndex());
+  }
+
+  @Test
+  public void crc32CompositeByteBufShouldIterateOverNioBuffers() {
+    byte[] part1 = new byte[128];
+    byte[] part2 = new byte[256];
+    Random random = new Random();
+    random.nextBytes(part1);
+    random.nextBytes(part2);
+
+    CompositeByteBuf composite = Unpooled.compositeBuffer();
+    composite.addComponent(true, Unpooled.wrappedBuffer(part1));
+    composite.addComponent(true, Unpooled.wrappedBuffer(part2));
+
+    // Ensure this test hits the composite path (nioBufferCount > 1).
+    // Note: CompositeByteBuf.nioBufferCount() depends on 
readerIndex/readableBytes, so check it
+    // here.
+    assertTrue(composite.nioBufferCount() > 1);
+
+    int skip = 13; // cross-component offsets are fine; expected CRC is 
computed on readable bytes.
+    composite.skipBytes(skip);
+
+    CRC32 crc32 = new CRC32();
+    crc32.update(part1, skip, part1.length - skip);
+    crc32.update(part2, 0, part2.length);
+    long expected = crc32.getValue();
+
+    int readerIndex = composite.readerIndex();
+    assertEquals(expected, ChecksumUtils.getCrc32(composite));
+    assertEquals(readerIndex, composite.readerIndex());
+  }
 }

Reply via email to