This is an automated email from the ASF dual-hosted git repository.

kerwinzhang pushed a commit to branch celeborn-755
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git

commit a09e0599e1065e6d7a92b43ad267f837e42d48b9
Author: xiyu.zk <[email protected]>
AuthorDate: Fri Jun 30 17:27:31 2023 +0800

    [CELEBORN-755] Support to decide whether to compress shuffle data through configuration
---
 .../apache/celeborn/client/ShuffleClientImpl.java  | 21 +++++----
 .../celeborn/client/read/RssInputStream.java       | 54 ++++++++++++++--------
 .../org/apache/celeborn/common/CelebornConf.scala  |  9 ++++
 3 files changed, 57 insertions(+), 27 deletions(-)

diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index e549400bb..faa446338 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -834,19 +834,24 @@ public class ShuffleClientImpl extends ShuffleClient {
     // increment batchId
     final int nextBatchId = pushState.nextBatchId();
 
-    // compress data
-    final Compressor compressor = compressorThreadLocal.get();
-    compressor.compress(data, offset, length);
+    int totalSize = data.length;
+    byte[] shuffleDataBuf = data;
 
-    final int compressedTotalSize = compressor.getCompressedTotalSize();
+    if (conf.shuffleCompressionEnabled()) {
+      // compress data
+      final Compressor compressor = compressorThreadLocal.get();
+      compressor.compress(data, offset, length);
 
-    final byte[] body = new byte[BATCH_HEADER_SIZE + compressedTotalSize];
+      totalSize = compressor.getCompressedTotalSize();
+      shuffleDataBuf = compressor.getCompressedBuffer();
+    }
+
+    final byte[] body = new byte[BATCH_HEADER_SIZE + totalSize];
     Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET, mapId);
     Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 4, attemptId);
     Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 8, nextBatchId);
-    Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 12, compressedTotalSize);
-    System.arraycopy(
-        compressor.getCompressedBuffer(), 0, body, BATCH_HEADER_SIZE, compressedTotalSize);
+    Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 12, totalSize);
+    System.arraycopy(shuffleDataBuf, 0, body, BATCH_HEADER_SIZE, totalSize);
 
     if (doPush) {
       // check limit
diff --git a/client/src/main/java/org/apache/celeborn/client/read/RssInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/RssInputStream.java
index e6d3b1df5..f3615438f 100644
--- a/client/src/main/java/org/apache/celeborn/client/read/RssInputStream.java
+++ b/client/src/main/java/org/apache/celeborn/client/read/RssInputStream.java
@@ -109,8 +109,8 @@ public abstract class RssInputStream extends InputStream {
     private final Map<Integer, Set<Integer>> batchesRead = new HashMap<>();
 
     private byte[] compressedBuf;
-    private byte[] decompressedBuf;
-    private final Decompressor decompressor;
+    private byte[] rawDataBuf;
+    private Decompressor decompressor;
 
     private ByteBuf currentChunk;
     private PartitionReader currentReader;
@@ -159,12 +159,15 @@ public abstract class RssInputStream extends InputStream {
       this.fetchExcludedWorkerExpireTimeout = conf.clientFetchExcludedWorkerExpireTimeout();
       this.fetchExcludedWorkers = fetchExcludedWorkers;
 
-      int headerLen = Decompressor.getCompressionHeaderLength(conf);
-      int blockSize = conf.clientPushBufferMaxSize() + headerLen;
-      compressedBuf = new byte[blockSize];
-      decompressedBuf = new byte[blockSize];
+      int blockSize = conf.clientPushBufferMaxSize();
+      if (conf.shuffleCompressionEnabled()) {
+        int headerLen = Decompressor.getCompressionHeaderLength(conf);
+        blockSize = conf.clientPushBufferMaxSize() + headerLen;
+        compressedBuf = new byte[blockSize];
 
-      decompressor = Decompressor.getDecompressor(conf);
+        decompressor = Decompressor.getDecompressor(conf);
+      }
+      rawDataBuf = new byte[blockSize];
 
       if (conf.clientPushReplicateEnabled()) {
         fetchChunkMaxRetry = conf.clientFetchMaxRetriesForEachReplica() * 2;
@@ -414,7 +417,7 @@ public abstract class RssInputStream extends InputStream {
     @Override
     public int read() throws IOException {
       if (position < limit) {
-        int b = decompressedBuf[position];
+        int b = rawDataBuf[position];
         position++;
         return b & 0xFF;
       }
@@ -426,7 +429,7 @@ public abstract class RssInputStream extends InputStream {
       if (position >= limit) {
         return read();
       } else {
-        int b = decompressedBuf[position];
+        int b = rawDataBuf[position];
         position++;
         return b & 0xFF;
       }
@@ -451,7 +454,7 @@ public abstract class RssInputStream extends InputStream {
         }
 
         int bytesToRead = Math.min(limit - position, len - readBytes);
-        System.arraycopy(decompressedBuf, position, b, off + readBytes, bytesToRead);
+        System.arraycopy(rawDataBuf, position, b, off + readBytes, bytesToRead);
         position += bytesToRead;
         readBytes += bytesToRead;
       }
@@ -512,11 +515,20 @@ public abstract class RssInputStream extends InputStream {
         int attemptId = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET + 4);
         int batchId = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET + 8);
         int size = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET + 12);
-        if (size > compressedBuf.length) {
-          compressedBuf = new byte[size];
-        }
 
-        currentChunk.readBytes(compressedBuf, 0, size);
+        if (conf.shuffleCompressionEnabled()) {
+          if (size > compressedBuf.length) {
+            compressedBuf = new byte[size];
+          }
+
+          currentChunk.readBytes(compressedBuf, 0, size);
+        } else {
+          if (size > rawDataBuf.length) {
+            rawDataBuf = new byte[size];
+          }
+
+          currentChunk.readBytes(rawDataBuf, 0, size);
+        }
 
         // de-duplicate
         if (attemptId == attempts[mapId]) {
@@ -530,12 +542,16 @@ public abstract class RssInputStream extends InputStream {
             if (callback != null) {
               callback.incBytesRead(BATCH_HEADER_SIZE + size);
             }
-            // decompress data
-            int originalLength = decompressor.getOriginalLen(compressedBuf);
-            if (decompressedBuf.length < originalLength) {
-              decompressedBuf = new byte[originalLength];
+            if (conf.shuffleCompressionEnabled()) {
+              // decompress data
+              int originalLength = decompressor.getOriginalLen(compressedBuf);
+              if (rawDataBuf.length < originalLength) {
+                rawDataBuf = new byte[originalLength];
+              }
+              limit = decompressor.decompress(compressedBuf, rawDataBuf, 0);
+            } else {
+              limit = size;
             }
-            limit = decompressor.decompress(compressedBuf, decompressedBuf, 0);
             position = 0;
             hasData = true;
             break;
diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index 72503a576..10804794e 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -701,6 +701,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
   // //////////////////////////////////////////////////////
   //               Shuffle Compression                   //
   // //////////////////////////////////////////////////////
+  def shuffleCompressionEnabled: Boolean = get(SHUFFLE_COMPRESSION_ENABLED)
   def shuffleCompressionCodec: CompressionCodec =
     CompressionCodec.valueOf(get(SHUFFLE_COMPRESSION_CODEC))
   def shuffleCompressionZstdCompressLevel: Int = get(SHUFFLE_COMPRESSION_ZSTD_LEVEL)
@@ -2934,6 +2935,14 @@ object CelebornConf extends Logging {
       .checkValues(Set(PartitionSplitMode.SOFT.name, PartitionSplitMode.HARD.name))
       .createWithDefault(PartitionSplitMode.SOFT.name)
 
+  val SHUFFLE_COMPRESSION_ENABLED: ConfigEntry[Boolean] =
+    buildConf("celeborn.client.shuffle.compression.enabled")
+      .categories("client")
+      .doc("whether to compress shuffle data.")
+      .version("0.3.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val SHUFFLE_COMPRESSION_CODEC: ConfigEntry[String] =
     buildConf("celeborn.client.shuffle.compression.codec")
       .withAlternative("celeborn.shuffle.compression.codec")

Reply via email to