This is an automated email from the ASF dual-hosted git repository.

chengpan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 381165d4e [CELEBORN-755] Support disable shuffle compression
381165d4e is described below

commit 381165d4e7ef1229bd88f08381cbd96f5cc91414
Author: xiyu.zk <[email protected]>
AuthorDate: Sat Jul 1 00:03:50 2023 +0800

    [CELEBORN-755] Support disable shuffle compression
    
    ### What changes were proposed in this pull request?
    Support deciding whether to compress shuffle data through configuration.
    
    ### Why are the changes needed?
    Currently, Celeborn compresses all shuffle data. However, some engines 
already compress their shuffle data (for example, Gluten), so no additional 
compression is required in those cases. Therefore, a configuration option 
needs to be provided so that users can decide whether to use Celeborn's 
compression according to their actual situation.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    Closes #1669 from kerwin-zk/celeborn-755.
    
    Authored-by: xiyu.zk <[email protected]>
    Signed-off-by: Cheng Pan <[email protected]>
---
 .../apache/celeborn/client/ShuffleClientImpl.java  | 25 +++++++---
 .../celeborn/client/read/RssInputStream.java       | 58 +++++++++++++++-------
 .../celeborn/client/ShuffleClientSuiteJ.java       | 50 +++++++++----------
 .../celeborn/common/protocol/CompressionCodec.java |  3 +-
 .../org/apache/celeborn/common/CelebornConf.scala  |  7 ++-
 docs/configuration/client.md                       |  2 +-
 .../cluster/ClusterReadWriteTestWithNONE.scala     | 13 +++--
 7 files changed, 97 insertions(+), 61 deletions(-)

diff --git 
a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java 
b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index e549400bb..0e943dacd 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -101,6 +101,7 @@ public class ShuffleClientImpl extends ShuffleClient {
   protected final Map<String, PushState> pushStates = 
JavaUtils.newConcurrentHashMap();
 
   private final boolean pushExcludeWorkerOnFailureEnabled;
+  private final boolean shuffleCompressionEnabled;
   private final Set<String> pushExcludedWorkers = 
ConcurrentHashMap.newKeySet();
   private final ConcurrentHashMap<String, Long> fetchExcludedWorkers =
       JavaUtils.newConcurrentHashMap();
@@ -164,6 +165,7 @@ public class ShuffleClientImpl extends ShuffleClient {
     testRetryRevive = conf.testRetryRevive();
     pushBufferMaxSize = conf.clientPushBufferMaxSize();
     pushExcludeWorkerOnFailureEnabled = 
conf.clientPushExcludeWorkerOnFailureEnabled();
+    shuffleCompressionEnabled = 
!conf.shuffleCompressionCodec().equals(CompressionCodec.NONE);
     if (conf.clientPushReplicateEnabled()) {
       pushDataTimeout = conf.pushDataTimeoutMs() * 2;
     } else {
@@ -834,19 +836,26 @@ public class ShuffleClientImpl extends ShuffleClient {
     // increment batchId
     final int nextBatchId = pushState.nextBatchId();
 
-    // compress data
-    final Compressor compressor = compressorThreadLocal.get();
-    compressor.compress(data, offset, length);
+    int totalSize = length;
+    byte[] shuffleDataBuf = new byte[length];
 
-    final int compressedTotalSize = compressor.getCompressedTotalSize();
+    if (shuffleCompressionEnabled) {
+      // compress data
+      final Compressor compressor = compressorThreadLocal.get();
+      compressor.compress(data, offset, length);
 
-    final byte[] body = new byte[BATCH_HEADER_SIZE + compressedTotalSize];
+      totalSize = compressor.getCompressedTotalSize();
+      shuffleDataBuf = compressor.getCompressedBuffer();
+    } else {
+      System.arraycopy(data, offset, shuffleDataBuf, 0, length);
+    }
+
+    final byte[] body = new byte[BATCH_HEADER_SIZE + totalSize];
     Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET, mapId);
     Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 4, attemptId);
     Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 8, nextBatchId);
-    Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 12, 
compressedTotalSize);
-    System.arraycopy(
-        compressor.getCompressedBuffer(), 0, body, BATCH_HEADER_SIZE, 
compressedTotalSize);
+    Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 12, totalSize);
+    System.arraycopy(shuffleDataBuf, 0, body, BATCH_HEADER_SIZE, totalSize);
 
     if (doPush) {
       // check limit
diff --git 
a/client/src/main/java/org/apache/celeborn/client/read/RssInputStream.java 
b/client/src/main/java/org/apache/celeborn/client/read/RssInputStream.java
index e6d3b1df5..bcc639780 100644
--- a/client/src/main/java/org/apache/celeborn/client/read/RssInputStream.java
+++ b/client/src/main/java/org/apache/celeborn/client/read/RssInputStream.java
@@ -36,6 +36,7 @@ import org.apache.celeborn.common.CelebornConf;
 import org.apache.celeborn.common.exception.CelebornIOException;
 import org.apache.celeborn.common.network.client.TransportClientFactory;
 import org.apache.celeborn.common.network.util.TransportConf;
+import org.apache.celeborn.common.protocol.CompressionCodec;
 import org.apache.celeborn.common.protocol.PartitionLocation;
 import org.apache.celeborn.common.protocol.StorageInfo;
 import org.apache.celeborn.common.protocol.TransportModuleConstants;
@@ -109,8 +110,8 @@ public abstract class RssInputStream extends InputStream {
     private final Map<Integer, Set<Integer>> batchesRead = new HashMap<>();
 
     private byte[] compressedBuf;
-    private byte[] decompressedBuf;
-    private final Decompressor decompressor;
+    private byte[] rawDataBuf;
+    private Decompressor decompressor;
 
     private ByteBuf currentChunk;
     private PartitionReader currentReader;
@@ -131,6 +132,7 @@ public abstract class RssInputStream extends InputStream {
 
     private boolean pushReplicateEnabled;
     private boolean fetchExcludeWorkerOnFailureEnabled;
+    private boolean shuffleCompressionEnabled;
     private long fetchExcludedWorkerExpireTimeout;
     private final ConcurrentHashMap<String, Long> fetchExcludedWorkers;
 
@@ -156,15 +158,20 @@ public abstract class RssInputStream extends InputStream {
       this.rangeReadFilter = conf.shuffleRangeReadFilterEnabled();
       this.pushReplicateEnabled = conf.clientPushReplicateEnabled();
       this.fetchExcludeWorkerOnFailureEnabled = 
conf.clientFetchExcludeWorkerOnFailureEnabled();
+      this.shuffleCompressionEnabled =
+          !conf.shuffleCompressionCodec().equals(CompressionCodec.NONE);
       this.fetchExcludedWorkerExpireTimeout = 
conf.clientFetchExcludedWorkerExpireTimeout();
       this.fetchExcludedWorkers = fetchExcludedWorkers;
 
-      int headerLen = Decompressor.getCompressionHeaderLength(conf);
-      int blockSize = conf.clientPushBufferMaxSize() + headerLen;
-      compressedBuf = new byte[blockSize];
-      decompressedBuf = new byte[blockSize];
+      int blockSize = conf.clientPushBufferMaxSize();
+      if (shuffleCompressionEnabled) {
+        int headerLen = Decompressor.getCompressionHeaderLength(conf);
+        blockSize = conf.clientPushBufferMaxSize() + headerLen;
+        compressedBuf = new byte[blockSize];
 
-      decompressor = Decompressor.getDecompressor(conf);
+        decompressor = Decompressor.getDecompressor(conf);
+      }
+      rawDataBuf = new byte[blockSize];
 
       if (conf.clientPushReplicateEnabled()) {
         fetchChunkMaxRetry = conf.clientFetchMaxRetriesForEachReplica() * 2;
@@ -414,7 +421,7 @@ public abstract class RssInputStream extends InputStream {
     @Override
     public int read() throws IOException {
       if (position < limit) {
-        int b = decompressedBuf[position];
+        int b = rawDataBuf[position];
         position++;
         return b & 0xFF;
       }
@@ -426,7 +433,7 @@ public abstract class RssInputStream extends InputStream {
       if (position >= limit) {
         return read();
       } else {
-        int b = decompressedBuf[position];
+        int b = rawDataBuf[position];
         position++;
         return b & 0xFF;
       }
@@ -451,7 +458,7 @@ public abstract class RssInputStream extends InputStream {
         }
 
         int bytesToRead = Math.min(limit - position, len - readBytes);
-        System.arraycopy(decompressedBuf, position, b, off + readBytes, 
bytesToRead);
+        System.arraycopy(rawDataBuf, position, b, off + readBytes, 
bytesToRead);
         position += bytesToRead;
         readBytes += bytesToRead;
       }
@@ -512,11 +519,20 @@ public abstract class RssInputStream extends InputStream {
         int attemptId = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET + 
4);
         int batchId = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET + 8);
         int size = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET + 12);
-        if (size > compressedBuf.length) {
-          compressedBuf = new byte[size];
-        }
 
-        currentChunk.readBytes(compressedBuf, 0, size);
+        if (shuffleCompressionEnabled) {
+          if (size > compressedBuf.length) {
+            compressedBuf = new byte[size];
+          }
+
+          currentChunk.readBytes(compressedBuf, 0, size);
+        } else {
+          if (size > rawDataBuf.length) {
+            rawDataBuf = new byte[size];
+          }
+
+          currentChunk.readBytes(rawDataBuf, 0, size);
+        }
 
         // de-duplicate
         if (attemptId == attempts[mapId]) {
@@ -530,12 +546,16 @@ public abstract class RssInputStream extends InputStream {
             if (callback != null) {
               callback.incBytesRead(BATCH_HEADER_SIZE + size);
             }
-            // decompress data
-            int originalLength = decompressor.getOriginalLen(compressedBuf);
-            if (decompressedBuf.length < originalLength) {
-              decompressedBuf = new byte[originalLength];
+            if (shuffleCompressionEnabled) {
+              // decompress data
+              int originalLength = decompressor.getOriginalLen(compressedBuf);
+              if (rawDataBuf.length < originalLength) {
+                rawDataBuf = new byte[originalLength];
+              }
+              limit = decompressor.decompress(compressedBuf, rawDataBuf, 0);
+            } else {
+              limit = size;
             }
-            limit = decompressor.decompress(compressedBuf, decompressedBuf, 0);
             position = 0;
             hasData = true;
             break;
diff --git 
a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java 
b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java
index da43f85c8..9ffc38579 100644
--- a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java
+++ b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java
@@ -105,11 +105,14 @@ public class ShuffleClientSuiteJ {
               1,
               1);
 
-      Compressor compressor = Compressor.getCompressor(conf);
-      compressor.compress(TEST_BUF1, 0, TEST_BUF1.length);
-      final int compressedTotalSize = compressor.getCompressedTotalSize();
-
-      assert (pushDataLen == compressedTotalSize + BATCH_HEADER_SIZE);
+      if (codec.equals(CompressionCodec.NONE)) {
+        assert (pushDataLen == TEST_BUF1.length + BATCH_HEADER_SIZE);
+      } else {
+        Compressor compressor = Compressor.getCompressor(conf);
+        compressor.compress(TEST_BUF1, 0, TEST_BUF1.length);
+        final int compressedTotalSize = compressor.getCompressedTotalSize();
+        assert (pushDataLen == compressedTotalSize + BATCH_HEADER_SIZE);
+      }
     }
   }
 
@@ -130,22 +133,14 @@ public class ShuffleClientSuiteJ {
               1,
               1);
 
-      Compressor compressor = Compressor.getCompressor(conf);
-      compressor.compress(TEST_BUF1, 0, TEST_BUF1.length);
-      final int compressedTotalSize = compressor.getCompressedTotalSize();
-
-      shuffleClient.mergeData(
-          TEST_SHUFFLE_ID,
-          TEST_ATTEMPT_ID,
-          TEST_ATTEMPT_ID,
-          TEST_REDUCRE_ID,
-          TEST_BUF1,
-          0,
-          TEST_BUF1.length,
-          1,
-          1);
-
-      assert (mergeSize == compressedTotalSize + BATCH_HEADER_SIZE);
+      if (codec.equals(CompressionCodec.NONE)) {
+        assert (mergeSize == TEST_BUF1.length + BATCH_HEADER_SIZE);
+      } else {
+        Compressor compressor = Compressor.getCompressor(conf);
+        compressor.compress(TEST_BUF1, 0, TEST_BUF1.length);
+        final int compressedTotalSize = compressor.getCompressedTotalSize();
+        assert (mergeSize == compressedTotalSize + BATCH_HEADER_SIZE);
+      }
 
       byte[] buf1k = 
RandomStringUtils.random(4000).getBytes(StandardCharsets.UTF_8);
       int largeMergeSize =
@@ -160,11 +155,14 @@ public class ShuffleClientSuiteJ {
               1,
               1);
 
-      compressor = Compressor.getCompressor(conf);
-      compressor.compress(buf1k, 0, buf1k.length);
-      int compressedTotalSize1 = compressor.getCompressedTotalSize();
-
-      assert (largeMergeSize == compressedTotalSize1 + BATCH_HEADER_SIZE);
+      if (codec.equals(CompressionCodec.NONE)) {
+        assert (largeMergeSize == buf1k.length + BATCH_HEADER_SIZE);
+      } else {
+        Compressor compressor = Compressor.getCompressor(conf);
+        compressor.compress(buf1k, 0, buf1k.length);
+        final int compressedTotalSize = compressor.getCompressedTotalSize();
+        assert (largeMergeSize == compressedTotalSize + BATCH_HEADER_SIZE);
+      }
     }
   }
 
diff --git 
a/common/src/main/java/org/apache/celeborn/common/protocol/CompressionCodec.java
 
b/common/src/main/java/org/apache/celeborn/common/protocol/CompressionCodec.java
index ec3c243c7..55983a11f 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/protocol/CompressionCodec.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/protocol/CompressionCodec.java
@@ -19,5 +19,6 @@ package org.apache.celeborn.common.protocol;
 
 public enum CompressionCodec {
   LZ4,
-  ZSTD;
+  ZSTD,
+  NONE;
 }
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala 
b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index 72503a576..af71bab48 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -2939,11 +2939,14 @@ object CelebornConf extends Logging {
       .withAlternative("celeborn.shuffle.compression.codec")
       .withAlternative("remote-shuffle.job.compression.codec")
       .categories("client")
-      .doc("The codec used to compress shuffle data. By default, Celeborn 
provides two codecs: `lz4` and `zstd`.")
+      .doc("The codec used to compress shuffle data. By default, Celeborn 
provides three codecs: `lz4`, `zstd`, `none`.")
       .version("0.3.0")
       .stringConf
       .transform(_.toUpperCase(Locale.ROOT))
-      .checkValues(Set(CompressionCodec.LZ4.name, CompressionCodec.ZSTD.name))
+      .checkValues(Set(
+        CompressionCodec.LZ4.name,
+        CompressionCodec.ZSTD.name,
+        CompressionCodec.NONE.name))
       .createWithDefault(CompressionCodec.LZ4.name)
 
   val SHUFFLE_COMPRESSION_ZSTD_LEVEL: ConfigEntry[Int] =
diff --git a/docs/configuration/client.md b/docs/configuration/client.md
index 8ecf65746..d88c1b64b 100644
--- a/docs/configuration/client.md
+++ b/docs/configuration/client.md
@@ -81,7 +81,7 @@ license: |
 | celeborn.client.shuffle.batchHandleReleasePartition.enabled | true | When 
true, LifecycleManager will handle release partition request in batch. 
Otherwise, LifecycleManager will process release partition request immediately 
| 0.3.0 | 
 | celeborn.client.shuffle.batchHandleReleasePartition.interval | 5s | Interval 
for LifecycleManager to schedule handling release partition requests in batch. 
| 0.3.0 | 
 | celeborn.client.shuffle.batchHandleReleasePartition.threads | 8 | Threads 
number for LifecycleManager to handle release partition request in batch. | 
0.3.0 | 
-| celeborn.client.shuffle.compression.codec | LZ4 | The codec used to compress 
shuffle data. By default, Celeborn provides two codecs: `lz4` and `zstd`. | 
0.3.0 | 
+| celeborn.client.shuffle.compression.codec | LZ4 | The codec used to compress 
shuffle data. By default, Celeborn provides three codecs: `lz4`, `zstd`, 
`none`. | 0.3.0 | 
 | celeborn.client.shuffle.compression.zstd.level | 1 | Compression level for 
Zstd compression codec, its value should be an integer between -5 and 22. 
Increasing the compression level will result in better compression at the 
expense of more CPU and memory. | 0.3.0 | 
 | celeborn.client.shuffle.expired.checkInterval | 60s | Interval for client to 
check expired shuffles. | 0.3.0 | 
 | celeborn.client.shuffle.manager.port | 0 | Port used by the LifecycleManager 
on the Driver. | 0.3.0 | 
diff --git 
a/common/src/main/java/org/apache/celeborn/common/protocol/CompressionCodec.java
 
b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ClusterReadWriteTestWithNONE.scala
similarity index 75%
copy from 
common/src/main/java/org/apache/celeborn/common/protocol/CompressionCodec.java
copy to 
worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ClusterReadWriteTestWithNONE.scala
index ec3c243c7..9b428fd2e 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/protocol/CompressionCodec.java
+++ 
b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ClusterReadWriteTestWithNONE.scala
@@ -15,9 +15,14 @@
  * limitations under the License.
  */
 
-package org.apache.celeborn.common.protocol;
+package org.apache.celeborn.service.deploy.cluster
+
+import org.apache.celeborn.common.protocol.CompressionCodec
+
+class ClusterReadWriteTestWithNONE extends ReadWriteTestBase {
+
+  test(s"test MiniCluster With NONE") {
+    testReadWriteByCode(CompressionCodec.NONE)
+  }
 
-public enum CompressionCodec {
-  LZ4,
-  ZSTD;
 }

Reply via email to