This is an automated email from the ASF dual-hosted git repository. kerwinzhang pushed a commit to branch celeborn-755 in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
commit a09e0599e1065e6d7a92b43ad267f837e42d48b9 Author: xiyu.zk <[email protected]> AuthorDate: Fri Jun 30 17:27:31 2023 +0800 [CELEBORN-755] Support to decide whether to compress shuffle data through configuration --- .../apache/celeborn/client/ShuffleClientImpl.java | 21 +++++---- .../celeborn/client/read/RssInputStream.java | 54 ++++++++++++++-------- .../org/apache/celeborn/common/CelebornConf.scala | 9 ++++ 3 files changed, 57 insertions(+), 27 deletions(-) diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index e549400bb..faa446338 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -834,19 +834,24 @@ public class ShuffleClientImpl extends ShuffleClient { // increment batchId final int nextBatchId = pushState.nextBatchId(); - // compress data - final Compressor compressor = compressorThreadLocal.get(); - compressor.compress(data, offset, length); + int totalSize = data.length; + byte[] shuffleDataBuf = data; - final int compressedTotalSize = compressor.getCompressedTotalSize(); + if (conf.shuffleCompressionEnabled()) { + // compress data + final Compressor compressor = compressorThreadLocal.get(); + compressor.compress(data, offset, length); - final byte[] body = new byte[BATCH_HEADER_SIZE + compressedTotalSize]; + totalSize = compressor.getCompressedTotalSize(); + shuffleDataBuf = compressor.getCompressedBuffer(); + } + + final byte[] body = new byte[BATCH_HEADER_SIZE + totalSize]; Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET, mapId); Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 4, attemptId); Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 8, nextBatchId); - Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 12, compressedTotalSize); - System.arraycopy( - compressor.getCompressedBuffer(), 0, body, BATCH_HEADER_SIZE, compressedTotalSize); + Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 12, totalSize); + System.arraycopy(shuffleDataBuf, 0, body, BATCH_HEADER_SIZE, totalSize); if (doPush) { // check limit diff --git a/client/src/main/java/org/apache/celeborn/client/read/RssInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/RssInputStream.java index e6d3b1df5..f3615438f 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/RssInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/RssInputStream.java @@ -109,8 +109,8 @@ public abstract class RssInputStream extends InputStream { private final Map<Integer, Set<Integer>> batchesRead = new HashMap<>(); private byte[] compressedBuf; - private byte[] decompressedBuf; - private final Decompressor decompressor; + private byte[] rawDataBuf; + private Decompressor decompressor; private ByteBuf currentChunk; private PartitionReader currentReader; @@ -159,12 +159,15 @@ public abstract class RssInputStream extends InputStream { this.fetchExcludedWorkerExpireTimeout = conf.clientFetchExcludedWorkerExpireTimeout(); this.fetchExcludedWorkers = fetchExcludedWorkers; - int headerLen = Decompressor.getCompressionHeaderLength(conf); - int blockSize = conf.clientPushBufferMaxSize() + headerLen; - compressedBuf = new byte[blockSize]; - decompressedBuf = new byte[blockSize]; + int blockSize = conf.clientPushBufferMaxSize(); + if (conf.shuffleCompressionEnabled()) { + int headerLen = Decompressor.getCompressionHeaderLength(conf); + blockSize = conf.clientPushBufferMaxSize() + headerLen; + compressedBuf = new byte[blockSize]; - decompressor = Decompressor.getDecompressor(conf); + decompressor = Decompressor.getDecompressor(conf); + } + rawDataBuf = new byte[blockSize]; if (conf.clientPushReplicateEnabled()) { fetchChunkMaxRetry = conf.clientFetchMaxRetriesForEachReplica() * 2; @@ -414,7 +417,7 @@ public abstract class RssInputStream extends InputStream { @Override public int read() throws IOException { if (position < limit) { - int b = decompressedBuf[position]; + int b = rawDataBuf[position]; position++; return b & 0xFF; } @@ -426,7 +429,7 @@ public abstract class RssInputStream extends InputStream { if (position >= limit) { return read(); } else { - int b = decompressedBuf[position]; + int b = rawDataBuf[position]; position++; return b & 0xFF; } @@ -451,7 +454,7 @@ public abstract class RssInputStream extends InputStream { } int bytesToRead = Math.min(limit - position, len - readBytes); - System.arraycopy(decompressedBuf, position, b, off + readBytes, bytesToRead); + System.arraycopy(rawDataBuf, position, b, off + readBytes, bytesToRead); position += bytesToRead; readBytes += bytesToRead; } @@ -512,11 +515,20 @@ public abstract class RssInputStream extends InputStream { int attemptId = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET + 4); int batchId = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET + 8); int size = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET + 12); - if (size > compressedBuf.length) { - compressedBuf = new byte[size]; - } - currentChunk.readBytes(compressedBuf, 0, size); + if (conf.shuffleCompressionEnabled()) { + if (size > compressedBuf.length) { + compressedBuf = new byte[size]; + } + + currentChunk.readBytes(compressedBuf, 0, size); + } else { + if (size > rawDataBuf.length) { + rawDataBuf = new byte[size]; + } + + currentChunk.readBytes(rawDataBuf, 0, size); + } // de-duplicate if (attemptId == attempts[mapId]) { @@ -530,12 +542,16 @@ public abstract class RssInputStream extends InputStream { if (callback != null) { callback.incBytesRead(BATCH_HEADER_SIZE + size); } - // decompress data - int originalLength = decompressor.getOriginalLen(compressedBuf); - if (decompressedBuf.length < originalLength) { - decompressedBuf = new byte[originalLength]; + if (conf.shuffleCompressionEnabled()) { + // decompress data + int originalLength = decompressor.getOriginalLen(compressedBuf); + if (rawDataBuf.length < originalLength) { + rawDataBuf = new byte[originalLength]; + } + limit = decompressor.decompress(compressedBuf, rawDataBuf, 0); + } else { + limit = size; } - limit = decompressor.decompress(compressedBuf, decompressedBuf, 0); position = 0; hasData = true; break; diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index 72503a576..10804794e 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -701,6 +701,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se // ////////////////////////////////////////////////////// // Shuffle Compression // // ////////////////////////////////////////////////////// + def shuffleCompressionEnabled: Boolean = get(SHUFFLE_COMPRESSION_ENABLED) def shuffleCompressionCodec: CompressionCodec = CompressionCodec.valueOf(get(SHUFFLE_COMPRESSION_CODEC)) def shuffleCompressionZstdCompressLevel: Int = get(SHUFFLE_COMPRESSION_ZSTD_LEVEL) @@ -2934,6 +2935,14 @@ object CelebornConf extends Logging { .checkValues(Set(PartitionSplitMode.SOFT.name, PartitionSplitMode.HARD.name)) .createWithDefault(PartitionSplitMode.SOFT.name) + val SHUFFLE_COMPRESSION_ENABLED: ConfigEntry[Boolean] = + buildConf("celeborn.client.shuffle.compression.enabled") + .categories("client") + .doc("whether to compress shuffle data.") + .version("0.3.0") + .booleanConf + .createWithDefault(true) + val SHUFFLE_COMPRESSION_CODEC: ConfigEntry[String] = buildConf("celeborn.client.shuffle.compression.codec") .withAlternative("celeborn.shuffle.compression.codec")
