This is an automated email from the ASF dual-hosted git repository.
zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 0495fbca9 [#2056] feat(client): Add a NONE type to bypass
de/compression for gluten (#2057)
0495fbca9 is described below
commit 0495fbca9509e74d07eb370aae4a9a4acd7a6e23
Author: Zhen Wang <[email protected]>
AuthorDate: Tue Aug 20 10:48:11 2024 +0800
[#2056] feat(client): Add a NONE type to bypass de/compression for gluten
(#2057)
### What changes were proposed in this pull request?
Add a NONE type to compression.codec.
### Why are the changes needed?
Allow disabling rss client compression when spark.shuffle.compress is
enabled
Fix: #2056
### Does this PR introduce _any_ user-facing change?
Yes. Adds a NONE compression type.
### How was this patch tested?
Existing unit tests.
---
.../hadoop/mapred/SortWriteBufferManager.java | 5 ++--
.../hadoop/mapreduce/task/reduce/RssFetcher.java | 25 ++++++++++------
.../shuffle/reader/RssShuffleDataIterator.java | 13 ++++----
.../spark/shuffle/writer/WriteBufferManager.java | 9 +++---
.../shuffle/reader/AbstractRssReaderTest.java | 6 ++--
.../shuffle/reader/RssShuffleDataIteratorTest.java | 9 ++++--
.../shuffle/writer/WriteBufferManagerTest.java | 8 +++--
.../library/common/shuffle/impl/RssTezFetcher.java | 25 ++++++++++------
.../orderedgrouped/RssTezShuffleDataFetcher.java | 25 ++++++++++------
.../common/sort/buffer/WriteBufferManager.java | 5 ++--
.../apache/uniffle/common/compression/Codec.java | 14 +++++----
.../uniffle/common/util/ByteBufferUtils.java | 35 ++++++++++++++++++++++
.../common/compression/CompressionTest.java | 2 +-
13 files changed, 126 insertions(+), 55 deletions(-)
diff --git
a/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java
b/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java
index 1860fe856..b31766652 100644
---
a/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java
+++
b/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java
@@ -23,6 +23,7 @@ import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
+import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@@ -96,7 +97,7 @@ public class SortWriteBufferManager<K, V> {
private final long maxBufferSize;
private final ExecutorService sendExecutorService;
private final RssConf rssConf;
- private final Codec codec;
+ private final Optional<Codec> codec;
private final Task.CombinerRunner<K, V> combinerRunner;
public SortWriteBufferManager(
@@ -383,7 +384,7 @@ public class SortWriteBufferManager<K, V> {
int partitionId = wb.getPartitionId();
final int uncompressLength = data.length;
long start = System.currentTimeMillis();
- final byte[] compressed = codec.compress(data);
+ final byte[] compressed = codec.map(c -> c.compress(data)).orElse(data);
final long crc32 = ChecksumUtils.getCrc32(compressed);
compressTime += System.currentTimeMillis() - start;
final long blockId =
diff --git
a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssFetcher.java
b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssFetcher.java
index b07581a2e..0e41490f8 100644
---
a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssFetcher.java
+++
b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssFetcher.java
@@ -20,6 +20,7 @@ package org.apache.hadoop.mapreduce.task.reduce;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.text.DecimalFormat;
+import java.util.Optional;
import com.google.common.annotations.VisibleForTesting;
import org.apache.hadoop.mapred.Counters;
@@ -38,6 +39,7 @@ import
org.apache.uniffle.client.response.CompressedShuffleBlock;
import org.apache.uniffle.common.compression.Codec;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.util.ByteBufferUtils;
import org.apache.uniffle.common.util.ByteUnit;
public class RssFetcher<K, V> {
@@ -90,7 +92,7 @@ public class RssFetcher<K, V> {
private int waitCount = 0;
private byte[] uncompressedData = null;
private RssConf rssConf;
- private Codec codec;
+ private Optional<Codec> codec;
RssFetcher(
JobConf job,
@@ -161,14 +163,19 @@ public class RssFetcher<K, V> {
// uncompress the block
if (!hasPendingData && compressedData != null) {
- final long startDecompress = System.currentTimeMillis();
- int uncompressedLen = compressedBlock.getUncompressLength();
- ByteBuffer decompressedBuffer = ByteBuffer.allocate(uncompressedLen);
- codec.decompress(compressedData, uncompressedLen, decompressedBuffer, 0);
- uncompressedData = decompressedBuffer.array();
- unCompressionLength += compressedBlock.getUncompressLength();
- long decompressDuration = System.currentTimeMillis() - startDecompress;
- decompressTime += decompressDuration;
+ if (codec.isPresent()) {
+ final long startDecompress = System.currentTimeMillis();
+ int uncompressedLen = compressedBlock.getUncompressLength();
+ ByteBuffer decompressedBuffer = ByteBuffer.allocate(uncompressedLen);
+ codec.get().decompress(compressedData, uncompressedLen,
decompressedBuffer, 0);
+ uncompressedData = decompressedBuffer.array();
+ unCompressionLength += compressedBlock.getUncompressLength();
+ long decompressDuration = System.currentTimeMillis() - startDecompress;
+ decompressTime += decompressDuration;
+ } else {
+ uncompressedData = ByteBufferUtils.bufferToArray(compressedData);
+ unCompressionLength += uncompressedData.length;
+ }
}
if (uncompressedData != null) {
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
index 88b2d22d8..4f9900ce7 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
@@ -19,6 +19,7 @@ package org.apache.spark.shuffle.reader;
import java.io.IOException;
import java.nio.ByteBuffer;
+import java.util.Optional;
import scala.Product2;
import scala.Tuple2;
@@ -59,7 +60,7 @@ public class RssShuffleDataIterator<K, C> extends
AbstractIterator<Product2<K, C
private long totalRawBytesLength = 0;
private long unCompressedBytesLength = 0;
private ByteBuffer uncompressedData;
- private Codec codec;
+ private Optional<Codec> codec;
public RssShuffleDataIterator(
Serializer serializer,
@@ -74,7 +75,7 @@ public class RssShuffleDataIterator<K, C> extends
AbstractIterator<Product2<K, C
RssSparkConfig.SPARK_SHUFFLE_COMPRESS_KEY.substring(
RssSparkConfig.SPARK_RSS_CONFIG_PREFIX.length()),
RssSparkConfig.SPARK_SHUFFLE_COMPRESS_DEFAULT);
- this.codec = compress ? Codec.newInstance(rssConf) : null;
+ this.codec = compress ? Codec.newInstance(rssConf) : Optional.empty();
}
public Iterator<Tuple2<Object, Object>> createKVIterator(ByteBuffer data) {
@@ -131,7 +132,7 @@ public class RssShuffleDataIterator<K, C> extends
AbstractIterator<Product2<K, C
shuffleReadClient.checkProcessedBlockIds();
shuffleReadClient.logStatics();
String decInfo =
- codec == null
+ !codec.isPresent()
? "."
: (", "
+ decompressTime
@@ -160,7 +161,7 @@ public class RssShuffleDataIterator<K, C> extends
AbstractIterator<Product2<K, C
shuffleReadMetrics.incRemoteBytesRead(rawDataLength);
int uncompressedLen = rawBlock.getUncompressLength();
- if (codec != null) {
+ if (codec.isPresent()) {
if (uncompressedData == null
|| uncompressedData.capacity() < uncompressedLen
|| !isSameMemoryType(uncompressedData, rawData)) {
@@ -185,7 +186,7 @@ public class RssShuffleDataIterator<K, C> extends
AbstractIterator<Product2<K, C
}
uncompressedData.clear();
long startDecompress = System.currentTimeMillis();
- codec.decompress(rawData, uncompressedLen, uncompressedData, 0);
+ codec.get().decompress(rawData, uncompressedLen, uncompressedData, 0);
unCompressedBytesLength += uncompressedLen;
long decompressDuration = System.currentTimeMillis() - startDecompress;
decompressTime += decompressDuration;
@@ -210,7 +211,7 @@ public class RssShuffleDataIterator<K, C> extends
AbstractIterator<Product2<K, C
// Uncompressed data is released in this class, Compressed data is release
in the class
// ShuffleReadClientImpl
// So if codec is null, we don't release the data when the stream is closed
- if (codec != null) {
+ if (codec.isPresent()) {
RssUtils.releaseByteBuffer(uncompressedData);
}
if (shuffleReadClient != null) {
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index bfd929777..08eec1c2e 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -22,6 +22,7 @@ import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
+import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
@@ -91,7 +92,7 @@ public class WriteBufferManager extends MemoryConsumer {
private long uncompressedDataLen = 0;
private long requireMemoryInterval;
private int requireMemoryRetryMax;
- private Codec codec;
+ private Optional<Codec> codec;
private Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>>
spillFunc;
private long sendSizeLimit;
private boolean memorySpillEnabled;
@@ -159,7 +160,7 @@ public class WriteBufferManager extends MemoryConsumer {
RssSparkConfig.SPARK_SHUFFLE_COMPRESS_KEY.substring(
RssSparkConfig.SPARK_RSS_CONFIG_PREFIX.length()),
RssSparkConfig.SPARK_SHUFFLE_COMPRESS_DEFAULT);
- this.codec = compress ? Codec.newInstance(rssConf) : null;
+ this.codec = compress ? Codec.newInstance(rssConf) : Optional.empty();
this.spillFunc = spillFunc;
this.sendSizeLimit =
rssConf.get(RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMITATION);
this.memorySpillTimeoutSec =
rssConf.get(RssSparkConfig.RSS_MEMORY_SPILL_TIMEOUT);
@@ -384,9 +385,9 @@ public class WriteBufferManager extends MemoryConsumer {
byte[] data = wb.getData();
final int uncompressLength = data.length;
byte[] compressed = data;
- if (codec != null) {
+ if (codec.isPresent()) {
long start = System.currentTimeMillis();
- compressed = codec.compress(data);
+ compressed = codec.get().compress(data);
compressTime += System.currentTimeMillis() - start;
}
final long crc32 = ChecksumUtils.getCrc32(compressed);
diff --git
a/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/AbstractRssReaderTest.java
b/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/AbstractRssReaderTest.java
index f761c6ea6..7099fd9eb 100644
---
a/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/AbstractRssReaderTest.java
+++
b/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/AbstractRssReaderTest.java
@@ -19,6 +19,7 @@ package org.apache.spark.shuffle.reader;
import java.util.List;
import java.util.Map;
+import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
@@ -171,8 +172,9 @@ public abstract class AbstractRssReaderTest extends
HadoopTestBase {
protected ShufflePartitionedBlock createShuffleBlock(
byte[] data, long blockId, boolean compress) {
byte[] compressData = data;
- if (compress) {
- compressData = Codec.newInstance(new RssConf()).compress(data);
+ Optional<Codec> codec = Codec.newInstance(new RssConf());
+ if (compress && codec.isPresent()) {
+ compressData = codec.get().compress(data);
}
long crc = ChecksumUtils.getCrc32(compressData);
return new ShufflePartitionedBlock(
diff --git
a/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/RssShuffleDataIteratorTest.java
b/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/RssShuffleDataIteratorTest.java
index 3f6993c82..5550f67c0 100644
---
a/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/RssShuffleDataIteratorTest.java
+++
b/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/RssShuffleDataIteratorTest.java
@@ -20,6 +20,7 @@ package org.apache.spark.shuffle.reader;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Map;
+import java.util.Optional;
import java.util.stream.Stream;
import com.google.common.collect.Lists;
@@ -46,6 +47,7 @@ import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.client.impl.ShuffleReadClientImpl;
import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.compression.Codec;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.util.BlockIdLayout;
import org.apache.uniffle.common.util.ChecksumUtils;
@@ -321,11 +323,12 @@ public class RssShuffleDataIteratorTest extends
AbstractRssReaderTest {
RssShuffleDataIterator rssShuffleDataIterator =
getDataIterator(
basePath, blockIdBitmap, taskIdBitmap, Lists.newArrayList(ssi1,
ssi2), compress);
- Object codec = FieldUtils.readField(rssShuffleDataIterator, "codec", true);
+ Optional<Codec> codec =
+ (Optional<Codec>) FieldUtils.readField(rssShuffleDataIterator,
"codec", true);
if (compress) {
- Assertions.assertNotNull(codec);
+ Assertions.assertTrue(codec.isPresent());
} else {
- Assertions.assertNull(codec);
+ Assertions.assertFalse(codec.isPresent());
}
validateResult(rssShuffleDataIterator, expectedData, 20);
diff --git
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
index 49ebeef25..19c9f6d10 100644
---
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
+++
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
@@ -21,6 +21,7 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
+import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
@@ -46,6 +47,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.compression.Codec;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.util.BlockIdLayout;
@@ -122,11 +124,11 @@ public class WriteBufferManagerTest {
conf.set(RssSparkConfig.SPARK_SHUFFLE_COMPRESS_KEY,
String.valueOf(false));
}
WriteBufferManager wbm = createManager(conf);
- Object codec = FieldUtils.readField(wbm, "codec", true);
+ Optional<Codec> codec = (Optional<Codec>) FieldUtils.readField(wbm,
"codec", true);
if (compress) {
- Assertions.assertNotNull(codec);
+ Assertions.assertTrue(codec.isPresent());
} else {
- Assertions.assertNull(codec);
+ Assertions.assertFalse(codec.isPresent());
}
wbm.setShuffleWriteMetrics(new ShuffleWriteMetrics());
String testKey = "Key";
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcher.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcher.java
index 5ff38333c..7edb4f443 100644
---
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcher.java
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcher.java
@@ -20,6 +20,7 @@ package org.apache.tez.runtime.library.common.shuffle.impl;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Map;
+import java.util.Optional;
import com.google.common.annotations.VisibleForTesting;
import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
@@ -36,6 +37,7 @@ import
org.apache.uniffle.client.response.CompressedShuffleBlock;
import org.apache.uniffle.common.compression.Codec;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.util.ByteBufferUtils;
public class RssTezFetcher {
private static final Logger LOG =
LoggerFactory.getLogger(RssTezFetcher.class);
@@ -62,7 +64,7 @@ public class RssTezFetcher {
private long startWait;
private int waitCount = 0;
private byte[] uncompressedData = null;
- private Codec codec;
+ private Optional<Codec> codec;
RssTezFetcher(
FetcherCallback fetcherCallback,
@@ -109,14 +111,19 @@ public class RssTezFetcher {
// uncompress the block
if (!hasPendingData && compressedData != null) {
- final long startDecompress = System.currentTimeMillis();
- int uncompressedLen = compressedBlock.getUncompressLength();
- ByteBuffer decompressedBuffer = ByteBuffer.allocate(uncompressedLen);
- codec.decompress(compressedData, uncompressedLen, decompressedBuffer, 0);
- uncompressedData = decompressedBuffer.array();
- unCompressionLength += compressedBlock.getUncompressLength();
- long decompressDuration = System.currentTimeMillis() - startDecompress;
- decompressTime += decompressDuration;
+ if (codec.isPresent()) {
+ final long startDecompress = System.currentTimeMillis();
+ int uncompressedLen = compressedBlock.getUncompressLength();
+ ByteBuffer decompressedBuffer = ByteBuffer.allocate(uncompressedLen);
+ codec.get().decompress(compressedData, uncompressedLen,
decompressedBuffer, 0);
+ uncompressedData = decompressedBuffer.array();
+ unCompressionLength += compressedBlock.getUncompressLength();
+ long decompressDuration = System.currentTimeMillis() - startDecompress;
+ decompressTime += decompressDuration;
+ } else {
+ uncompressedData = ByteBufferUtils.bufferToArray(compressedData);
+ unCompressionLength += uncompressedData.length;
+ }
}
if (uncompressedData != null) {
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssTezShuffleDataFetcher.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssTezShuffleDataFetcher.java
index 992f509d7..06de81013 100644
---
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssTezShuffleDataFetcher.java
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssTezShuffleDataFetcher.java
@@ -19,6 +19,7 @@ package
org.apache.tez.runtime.library.common.shuffle.orderedgrouped;
import java.io.IOException;
import java.nio.ByteBuffer;
+import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import com.google.common.annotations.VisibleForTesting;
@@ -34,6 +35,7 @@ import
org.apache.uniffle.client.response.CompressedShuffleBlock;
import org.apache.uniffle.common.compression.Codec;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.util.ByteBufferUtils;
public class RssTezShuffleDataFetcher extends CallableWithNdc<Void> {
private static final Logger LOG =
LoggerFactory.getLogger(RssTezShuffleDataFetcher.class);
@@ -70,7 +72,7 @@ public class RssTezShuffleDataFetcher extends
CallableWithNdc<Void> {
private long startWait;
private int waitCount = 0;
private byte[] uncompressedData = null;
- private final Codec rssCodec;
+ private final Optional<Codec> rssCodec;
private Integer partitionId;
private final ExceptionReporter exceptionReporter;
@@ -151,14 +153,19 @@ public class RssTezShuffleDataFetcher extends
CallableWithNdc<Void> {
// uncompress the block
if (!hasPendingData && compressedData != null) {
- final long startDecompress = System.currentTimeMillis();
- int uncompressedLen = compressedBlock.getUncompressLength();
- ByteBuffer decompressedBuffer = ByteBuffer.allocate(uncompressedLen);
- rssCodec.decompress(compressedData, uncompressedLen, decompressedBuffer,
0);
- uncompressedData = decompressedBuffer.array();
- unCompressionLength += compressedBlock.getUncompressLength();
- long decompressDuration = System.currentTimeMillis() - startDecompress;
- decompressTime += decompressDuration;
+ if (rssCodec.isPresent()) {
+ final long startDecompress = System.currentTimeMillis();
+ int uncompressedLen = compressedBlock.getUncompressLength();
+ ByteBuffer decompressedBuffer = ByteBuffer.allocate(uncompressedLen);
+ rssCodec.get().decompress(compressedData, uncompressedLen,
decompressedBuffer, 0);
+ uncompressedData = decompressedBuffer.array();
+ unCompressionLength += compressedBlock.getUncompressLength();
+ long decompressDuration = System.currentTimeMillis() - startDecompress;
+ decompressTime += decompressDuration;
+ } else {
+ uncompressedData = ByteBufferUtils.bufferToArray(compressedData);
+ unCompressionLength += uncompressedData.length;
+ }
}
if (uncompressedData != null) {
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java
index 53cfeba45..93735efa4 100644
---
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java
@@ -23,6 +23,7 @@ import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
+import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@@ -84,7 +85,7 @@ public class WriteBufferManager<K, V> {
private final double memoryThreshold;
private final double sendThreshold;
private final int batch;
- private final Codec codec;
+ private final Optional<Codec> codec;
private final Map<Integer, List<ShuffleServerInfo>> partitionToServers;
private final Set<Long> allBlockIds = Sets.newConcurrentHashSet();
// server -> partitionId -> blockIds
@@ -370,7 +371,7 @@ public class WriteBufferManager<K, V> {
final int uncompressLength = data.length;
long start = System.currentTimeMillis();
- final byte[] compressed = codec.compress(data);
+ final byte[] compressed = codec.map(c -> c.compress(data)).orElse(data);
final long crc32 = ChecksumUtils.getCrc32(compressed);
compressTime += System.currentTimeMillis() - start;
final long blockId =
diff --git
a/common/src/main/java/org/apache/uniffle/common/compression/Codec.java
b/common/src/main/java/org/apache/uniffle/common/compression/Codec.java
index b2ac5f0bb..72c69dc06 100644
--- a/common/src/main/java/org/apache/uniffle/common/compression/Codec.java
+++ b/common/src/main/java/org/apache/uniffle/common/compression/Codec.java
@@ -18,6 +18,7 @@
package org.apache.uniffle.common.compression;
import java.nio.ByteBuffer;
+import java.util.Optional;
import org.apache.uniffle.common.config.RssConf;
@@ -26,18 +27,20 @@ import static
org.apache.uniffle.common.config.RssClientConf.ZSTD_COMPRESSION_LE
public abstract class Codec {
- public static Codec newInstance(RssConf rssConf) {
+ public static Optional<Codec> newInstance(RssConf rssConf) {
Type type = rssConf.get(COMPRESSION_TYPE);
switch (type) {
+ case NONE:
+ return Optional.empty();
case ZSTD:
- return ZstdCodec.getInstance(rssConf.get(ZSTD_COMPRESSION_LEVEL));
+ return
Optional.of(ZstdCodec.getInstance(rssConf.get(ZSTD_COMPRESSION_LEVEL)));
case SNAPPY:
- return SnappyCodec.getInstance();
+ return Optional.of(SnappyCodec.getInstance());
case NOOP:
- return NoOpCodec.getInstance();
+ return Optional.of(NoOpCodec.getInstance());
case LZ4:
default:
- return Lz4Codec.getInstance();
+ return Optional.of(Lz4Codec.getInstance());
}
}
@@ -72,5 +75,6 @@ public abstract class Codec {
ZSTD,
NOOP,
SNAPPY,
+ NONE,
}
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/util/ByteBufferUtils.java
b/common/src/main/java/org/apache/uniffle/common/util/ByteBufferUtils.java
new file mode 100644
index 000000000..f32f3d5a8
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/util/ByteBufferUtils.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.util;
+
+import java.nio.ByteBuffer;
+
+public class ByteBufferUtils {
+
+ public static byte[] bufferToArray(ByteBuffer buffer) {
+ if (buffer.hasArray()
+ && buffer.arrayOffset() == 0
+ && buffer.array().length == buffer.remaining()) {
+ return buffer.array();
+ } else {
+ byte[] bytes = new byte[buffer.remaining()];
+ buffer.get(bytes);
+ return bytes;
+ }
+ }
+}
diff --git
a/common/src/test/java/org/apache/uniffle/common/compression/CompressionTest.java
b/common/src/test/java/org/apache/uniffle/common/compression/CompressionTest.java
index 629ad4728..ac5af5aa7 100644
---
a/common/src/test/java/org/apache/uniffle/common/compression/CompressionTest.java
+++
b/common/src/test/java/org/apache/uniffle/common/compression/CompressionTest.java
@@ -56,7 +56,7 @@ public class CompressionTest {
conf.set(COMPRESSION_TYPE, type);
// case1: heap bytebuffer
- Codec codec = Codec.newInstance(conf);
+ Codec codec = Codec.newInstance(conf).get();
byte[] compressed = codec.compress(data);
ByteBuffer dest = ByteBuffer.allocate(size);