This is an automated email from the ASF dual-hosted git repository.
zhifgli pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 01def93f Supports ZSTD (#254)
01def93f is described below
commit 01def93fff1a40a83676d07821754ba0d91d65f2
Author: Junfan Zhang <[email protected]>
AuthorDate: Wed Oct 26 19:05:03 2022 +0800
Supports ZSTD (#254)
### What changes were proposed in this pull request?
1. Introduce ZSTD compression
2. Introduce an abstract codec interface (a usage sketch follows this list)
3. Recycle the decompression buffer to optimize performance
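For reference, here is a minimal round-trip sketch against the new `Codec` API (the class name `CodecUsageSketch` and the sample payload are illustrative, not part of this patch):

```java
import java.nio.ByteBuffer;

import org.apache.uniffle.common.compression.Codec;
import org.apache.uniffle.common.config.RssConf;

public class CodecUsageSketch {
  public static void main(String[] args) {
    // Select the codec through the new client option (the default is LZ4).
    RssConf conf = new RssConf();
    conf.setString("rss.client.io.compression.codec", "ZSTD");
    conf.setString("rss.client.io.compression.zstd.level", "3");

    // Codec.newInstance returns a ZstdCodec, NoOpCodec, or Lz4Codec.
    Codec codec = Codec.newInstance(conf);

    // Write path: each shuffle block is compressed into a fresh byte[].
    byte[] block = "some shuffle data".getBytes();
    byte[] compressed = codec.compress(block);

    // Read path: the caller owns the destination buffer, so it can be grown
    // and reused across blocks, which is how RssShuffleDataIterator recycles it.
    ByteBuffer uncompressed = ByteBuffer.allocate(block.length);
    codec.decompress(ByteBuffer.wrap(compressed), block.length, uncompressed, 0);
  }
}
```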
### Why are the changes needed?
ZSTD offers a good tradeoff between compression ratio and compression/decompression
speed. To reduce the stored size of shuffle data, it is worth supporting this
compression algorithm.
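For example, assuming the usual `spark.` and `mapreduce.` prefixes that the new `toRssConf` helpers strip before keys reach `RssClientConf` (the exact prefix constants are defined in `RssSparkConfig` and `RssMRConfig`), enabling ZSTD might look like:

```
# Spark
spark.rss.client.io.compression.codec ZSTD
spark.rss.client.io.compression.zstd.level 3

# MapReduce
mapreduce.rss.client.io.compression.codec ZSTD
```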
### Does this PR introduce _any_ user-facing change?
Yes
### How was this patch tested?
Manual tests and UTs
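The added `CompressionTest` covers the codecs; a condensed round-trip check in the same spirit (the class name `CodecRoundTripTest` is illustrative) could look like:

```java
import java.nio.ByteBuffer;

import org.apache.commons.lang3.RandomUtils;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

import org.apache.uniffle.common.compression.Codec;
import org.apache.uniffle.common.config.RssConf;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;

public class CodecRoundTripTest {
  @ParameterizedTest
  @ValueSource(strings = {"LZ4", "ZSTD", "NOOP"})
  public void testRoundTrip(String type) {
    RssConf conf = new RssConf();
    conf.setString("rss.client.io.compression.codec", type);
    Codec codec = Codec.newInstance(conf);

    // Compress random data, then decompress into a caller-owned heap buffer.
    byte[] data = RandomUtils.nextBytes(128 * 1024);
    byte[] compressed = codec.compress(data);
    ByteBuffer dest = ByteBuffer.allocate(data.length);
    codec.decompress(ByteBuffer.wrap(compressed), data.length, dest, 0);
    assertArrayEquals(data, dest.array());
  }
}
```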
---
client-mr/pom.xml | 4 +
.../hadoop/mapred/RssMapOutputCollector.java | 3 +-
.../hadoop/mapred/SortWriteBufferManager.java | 12 ++-
.../org/apache/hadoop/mapreduce/RssMRConfig.java | 16 ++++
.../hadoop/mapreduce/task/reduce/RssFetcher.java | 27 ++++--
.../hadoop/mapreduce/task/reduce/RssShuffle.java | 2 +-
.../hadoop/mapred/SortWriteBufferManagerTest.java | 13 ++-
.../hadoop/mapreduce/task/reduce/FetcherTest.java | 26 ++++--
.../org/apache/spark/shuffle/RssSparkConfig.java | 16 ++++
.../shuffle/reader/RssShuffleDataIterator.java | 41 +++++----
.../spark/shuffle/writer/WriteBufferManager.java | 10 ++-
.../shuffle/reader/AbstractRssReaderTest.java | 5 +-
.../shuffle/reader/RssShuffleDataIteratorTest.java | 6 +-
.../shuffle/writer/WriteBufferManagerTest.java | 3 +-
.../apache/spark/shuffle/RssShuffleManager.java | 14 ++-
.../spark/shuffle/reader/RssShuffleReader.java | 8 +-
.../spark/shuffle/reader/RssShuffleReaderTest.java | 3 +-
.../spark/shuffle/writer/RssShuffleWriterTest.java | 11 ++-
.../apache/spark/shuffle/RssShuffleManager.java | 5 +-
.../spark/shuffle/reader/RssShuffleReader.java | 8 +-
.../spark/shuffle/reader/RssShuffleReaderTest.java | 8 +-
.../spark/shuffle/writer/RssShuffleWriterTest.java | 27 ++----
common/pom.xml | 5 ++
.../org/apache/uniffle/common/RssShuffleUtils.java | 36 --------
.../apache/uniffle/common/compression/Codec.java | 51 +++++++++++
.../uniffle/common/compression/Lz4Codec.java | 41 +++++++++
.../uniffle/common/compression/NoOpCodec.java | 35 ++++++++
.../uniffle/common/compression/ZstdCodec.java | 67 +++++++++++++++
.../uniffle/common/config/RssClientConf.java | 38 +++++++++
.../apache/uniffle/common/RssShuffleUtilsTest.java | 99 ----------------------
.../common/compression/CompressionTest.java | 83 ++++++++++++++++++
docs/client_guide.md | 2 +
.../test/RepartitionWithLocalFileRssTest.java | 38 +++++++++
.../uniffle/test/SparkIntegrationTestBase.java | 2 +-
pom.xml | 7 ++
35 files changed, 549 insertions(+), 223 deletions(-)
diff --git a/client-mr/pom.xml b/client-mr/pom.xml
index 2c2332f4..b91c3160 100644
--- a/client-mr/pom.xml
+++ b/client-mr/pom.xml
@@ -105,6 +105,10 @@
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>com.github.luben</groupId>
+ <artifactId>zstd-jni</artifactId>
+ </dependency>
</dependencies>
<build>
diff --git a/client-mr/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java b/client-mr/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java
index 308a560c..c9cb553f 100644
--- a/client-mr/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java
+++ b/client-mr/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java
@@ -130,7 +130,8 @@ public class RssMapOutputCollector<K extends Object, V extends Object>
isMemoryShuffleEnabled(storageType),
sendThreadNum,
sendThreshold,
- maxBufferSize);
+ maxBufferSize,
+ RssMRConfig.toRssConf(rssJobConf));
}
private Map<Integer, List<ShuffleServerInfo>> createAssignmentMap(JobConf jobConf) {
diff --git a/client-mr/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java b/client-mr/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java
index aa4da547..36ade47e 100644
--- a/client-mr/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java
+++ b/client-mr/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java
@@ -43,9 +43,10 @@ import org.slf4j.LoggerFactory;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.response.SendShuffleDataResult;
-import org.apache.uniffle.common.RssShuffleUtils;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.compression.Codec;
+import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.ChecksumUtils;
import org.apache.uniffle.common.util.ThreadUtils;
@@ -90,6 +91,8 @@ public class SortWriteBufferManager<K, V> {
private long sortTime = 0;
private final long maxBufferSize;
private final ExecutorService sendExecutorService;
+ private final RssConf rssConf;
+ private final Codec codec;
public SortWriteBufferManager(
long maxMemSize,
@@ -114,7 +117,8 @@ public class SortWriteBufferManager<K, V> {
boolean isMemoryShuffleEnabled,
int sendThreadNum,
double sendThreshold,
- long maxBufferSize) {
+ long maxBufferSize,
+ RssConf rssConf) {
this.maxMemSize = maxMemSize;
this.taskAttemptId = taskAttemptId;
this.batch = batch;
@@ -140,6 +144,8 @@ public class SortWriteBufferManager<K, V> {
this.sendExecutorService = Executors.newFixedThreadPool(
sendThreadNum,
ThreadUtils.getThreadFactory("send-thread-%d"));
+ this.rssConf = rssConf;
+ this.codec = Codec.newInstance(rssConf);
}
// todo: Single Buffer should also have its size limit
@@ -309,7 +315,7 @@ public class SortWriteBufferManager<K, V> {
int partitionId = wb.getPartitionId();
final int uncompressLength = data.length;
long start = System.currentTimeMillis();
- final byte[] compressed = RssShuffleUtils.compressData(data);
+ final byte[] compressed = codec.compress(data);
final long crc32 = ChecksumUtils.getCrc32(compressed);
compressTime += System.currentTimeMillis() - start;
final long blockId = RssMRUtils.getBlockId((long)partitionId, taskAttemptId, getNextSeqNo(partitionId));
diff --git a/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java b/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java
index eb518162..d89b4f12 100644
--- a/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java
+++ b/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java
@@ -17,11 +17,14 @@
package org.apache.hadoop.mapreduce;
+import java.util.Map;
import java.util.Set;
import com.google.common.collect.ImmutableSet;
+import org.apache.hadoop.mapred.JobConf;
import org.apache.uniffle.client.util.RssClientConfig;
+import org.apache.uniffle.common.config.RssConf;
public class RssMRConfig {
@@ -164,4 +167,17 @@ public class RssMRConfig {
public static final Set<String> RSS_MANDATORY_CLUSTER_CONF =
ImmutableSet.of(RSS_STORAGE_TYPE, RSS_REMOTE_STORAGE_PATH);
+
+ public static RssConf toRssConf(JobConf jobConf) {
+ RssConf rssConf = new RssConf();
+ for (Map.Entry<String, String> entry : jobConf) {
+ String key = entry.getKey();
+ if (!key.startsWith(MR_RSS_CONFIG_PREFIX)) {
+ continue;
+ }
+ key = key.substring(MR_RSS_CONFIG_PREFIX.length());
+ rssConf.setString(key, entry.getValue());
+ }
+ return rssConf;
+ }
}
diff --git a/client-mr/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssFetcher.java b/client-mr/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssFetcher.java
index 128bfb9f..8e0859cc 100644
--- a/client-mr/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssFetcher.java
+++ b/client-mr/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssFetcher.java
@@ -35,7 +35,8 @@ import org.apache.hadoop.util.Progress;
import org.apache.uniffle.client.api.ShuffleReadClient;
import org.apache.uniffle.client.response.CompressedShuffleBlock;
-import org.apache.uniffle.common.RssShuffleUtils;
+import org.apache.uniffle.common.compression.Codec;
+import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.ByteUnit;
@@ -84,14 +85,17 @@ public class RssFetcher<K,V> {
private long startWait;
private int waitCount = 0;
private byte[] uncompressedData = null;
+ private RssConf rssConf;
+ private Codec codec;
RssFetcher(JobConf job, TaskAttemptID reduceId,
- TaskStatus status,
- MergeManager<K,V> merger,
- Progress progress,
- Reporter reporter, ShuffleClientMetrics metrics,
- ShuffleReadClient shuffleReadClient,
- long totalBlockCount) {
+ TaskStatus status,
+ MergeManager<K, V> merger,
+ Progress progress,
+ Reporter reporter, ShuffleClientMetrics metrics,
+ ShuffleReadClient shuffleReadClient,
+ long totalBlockCount,
+ RssConf rssConf) {
this.jobConf = job;
this.reporter = reporter;
this.status = status;
@@ -114,6 +118,9 @@ public class RssFetcher<K,V> {
this.shuffleReadClient = shuffleReadClient;
this.totalBlockCount = totalBlockCount;
+
+ this.rssConf = rssConf;
+ this.codec = Codec.newInstance(rssConf);
}
public void fetchAllRssBlocks() throws IOException, InterruptedException {
@@ -150,8 +157,10 @@ public class RssFetcher<K,V> {
// uncompress the block
if (!hasPendingData && compressedData != null) {
final long startDecompress = System.currentTimeMillis();
- uncompressedData = RssShuffleUtils.decompressData(
- compressedData, compressedBlock.getUncompressLength(), false).array();
+ int uncompressedLen = compressedBlock.getUncompressLength();
+ ByteBuffer decompressedBuffer = ByteBuffer.allocate(uncompressedLen);
+ codec.decompress(compressedData, uncompressedLen, decompressedBuffer, 0);
+ uncompressedData = decompressedBuffer.array();
unCompressionLength += compressedBlock.getUncompressLength();
long decompressDuration = System.currentTimeMillis() - startDecompress;
decompressTime += decompressDuration;
diff --git a/client-mr/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssShuffle.java b/client-mr/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssShuffle.java
index 1d30df96..e5af9795 100644
--- a/client-mr/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssShuffle.java
+++ b/client-mr/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssShuffle.java
@@ -197,7 +197,7 @@ public class RssShuffle<K, V> implements ShuffleConsumerPlugin<K, V>, ExceptionR
readerJobConf, new MRIdHelper());
ShuffleReadClient shuffleReadClient = ShuffleClientFactory.getInstance().createShuffleReadClient(request);
RssFetcher fetcher = new RssFetcher(mrJobConf, reduceId, taskStatus, merger, copyPhase, reporter, metrics,
- shuffleReadClient, blockIdBitmap.getLongCardinality());
+ shuffleReadClient, blockIdBitmap.getLongCardinality(), RssMRConfig.toRssConf(rssJobConf));
fetcher.fetchAllRssBlocks();
LOG.info("In reduce: " + reduceId
+ ", Rss MR client fetches blocks from RSS server successfully");
diff --git a/client-mr/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java b/client-mr/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
index 029b1e0e..305a9dcb 100644
--- a/client-mr/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
+++ b/client-mr/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
@@ -38,6 +38,7 @@ import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleAssignmentsInfo;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -79,7 +80,8 @@ public class SortWriteBufferManagerTest {
true,
5,
0.2f,
- 1024000L);
+ 1024000L,
+ new RssConf());
Random random = new Random();
for (int i = 0; i < 1000; i++) {
byte[] key = new byte[20];
@@ -128,7 +130,8 @@ public class SortWriteBufferManagerTest {
true,
5,
0.2f,
- 1024000L);
+ 1024000L,
+ new RssConf());
byte[] key = new byte[20];
byte[] value = new byte[1024];
random.nextBytes(key);
@@ -176,7 +179,8 @@ public class SortWriteBufferManagerTest {
true,
5,
0.2f,
- 100L);
+ 100L,
+ new RssConf());
Random random = new Random();
for (int i = 0; i < 1000; i++) {
byte[] key = new byte[20];
@@ -223,7 +227,8 @@ public class SortWriteBufferManagerTest {
true,
5,
0.2f,
- 1024000L);
+ 1024000L,
+ new RssConf());
Random random = new Random();
for (int i = 0; i < 1000; i++) {
byte[] key = new byte[20];
diff --git a/client-mr/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java b/client-mr/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
index ec630e24..b5404e59 100644
--- a/client-mr/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
+++ b/client-mr/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
@@ -65,10 +65,12 @@ import org.apache.uniffle.client.response.CompressedShuffleBlock;
import org.apache.uniffle.client.response.SendShuffleDataResult;
import org.apache.uniffle.common.PartitionRange;
import org.apache.uniffle.common.RemoteStorageInfo;
-import org.apache.uniffle.common.RssShuffleUtils;
import org.apache.uniffle.common.ShuffleAssignmentsInfo;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.compression.Codec;
+import org.apache.uniffle.common.compression.Lz4Codec;
+import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -88,6 +90,8 @@ public class FetcherTest {
static List<byte[]> data;
static MergeManagerImpl<Text, Text> merger;
+ static Codec codec = new Lz4Codec();
+
@Test
public void writeAndReadDataTestWithRss() throws Throwable {
fs = FileSystem.getLocal(conf);
@@ -97,7 +101,7 @@ public class FetcherTest {
null, null, new Progress(), new MROutputFiles());
ShuffleReadClient shuffleReadClient = new MockedShuffleReadClient(data);
RssFetcher fetcher = new RssFetcher(jobConf, reduceId1, taskStatus, merger, new Progress(),
- reporter, metrics, shuffleReadClient, 3);
+ reporter, metrics, shuffleReadClient, 3, new RssConf());
fetcher.fetchAllRssBlocks();
@@ -128,7 +132,7 @@ public class FetcherTest {
null, null, new Progress(), new MROutputFiles());
ShuffleReadClient shuffleReadClient = new MockedShuffleReadClient(data);
RssFetcher fetcher = new RssFetcher(jobConf, reduceId1, taskStatus, merger, new Progress(),
- reporter, metrics, shuffleReadClient, 3);
+ reporter, metrics, shuffleReadClient, 3, new RssConf());
fetcher.fetchAllRssBlocks();
@@ -161,7 +165,7 @@ public class FetcherTest {
null, null, new Progress(), new MROutputFiles(), expectedFails);
ShuffleReadClient shuffleReadClient = new MockedShuffleReadClient(data);
RssFetcher fetcher = new RssFetcher(jobConf, reduceId1, taskStatus, merger, new Progress(),
- reporter, metrics, shuffleReadClient, 3);
+ reporter, metrics, shuffleReadClient, 3, new RssConf());
fetcher.fetchAllRssBlocks();
RawKeyValueIterator iterator = merger.close();
@@ -276,7 +280,8 @@ public class FetcherTest {
true,
5,
0.2f,
- 1024000L);
+ 1024000L,
+ new RssConf());
for (String key : keysToValues.keySet()) {
String value = keysToValues.get(key);
@@ -357,7 +362,14 @@ public class FetcherTest {
successBlockIds.add(blockInfo.getBlockId());
}
shuffleBlockInfoList.forEach(block -> {
- data.add(RssShuffleUtils.decompressData(block.getData(), block.getUncompressLength()));
+ ByteBuffer uncompressedBuffer = ByteBuffer.allocate(block.getUncompressLength());
+ codec.decompress(
+ ByteBuffer.wrap(block.getData()),
+ block.getUncompressLength(),
+ uncompressedBuffer,
+ 0
+ );
+ data.add(uncompressedBuffer.array());
});
return new SendShuffleDataResult(successBlockIds, Sets.newHashSet());
}
@@ -440,7 +452,7 @@ public class FetcherTest {
MockedShuffleReadClient(List<byte[]> data) {
this.blocks = new LinkedList<>();
data.forEach(bytes -> {
- byte[] compressed = RssShuffleUtils.compressData(bytes);
+ byte[] compressed = codec.compress(bytes);
blocks.add(new CompressedShuffleBlock(ByteBuffer.wrap(compressed), bytes.length));
});
}
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index 5f39eb5d..71b4c283 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -20,13 +20,16 @@ package org.apache.spark.shuffle;
import java.util.Set;
import com.google.common.collect.ImmutableSet;
+import org.apache.spark.SparkConf;
import org.apache.spark.internal.config.ConfigBuilder;
import org.apache.spark.internal.config.ConfigEntry;
import org.apache.spark.internal.config.TypedConfigBuilder;
+import scala.Tuple2;
import scala.runtime.AbstractFunction1;
import org.apache.uniffle.client.util.RssClientConfig;
import org.apache.uniffle.common.config.ConfigUtils;
+import org.apache.uniffle.common.config.RssConf;
public class RssSparkConfig {
@@ -286,4 +289,17 @@ public class RssSparkConfig {
public static TypedConfigBuilder<String> createStringBuilder(ConfigBuilder builder) {
return builder.stringConf();
}
+
+ public static RssConf toRssConf(SparkConf sparkConf) {
+ RssConf rssConf = new RssConf();
+ for (Tuple2<String, String> tuple : sparkConf.getAll()) {
+ String key = tuple._1;
+ if (!key.startsWith(SPARK_RSS_CONFIG_PREFIX)) {
+ continue;
+ }
+ key = key.substring(SPARK_RSS_CONFIG_PREFIX.length());
+ rssConf.setString(key, tuple._2);
+ }
+ return rssConf;
+ }
}
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
index 23e03641..7ba3e066 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
@@ -28,6 +28,7 @@ import org.apache.spark.executor.ShuffleReadMetrics;
import org.apache.spark.serializer.DeserializationStream;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.RssSparkConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Product2;
@@ -38,8 +39,9 @@ import scala.runtime.BoxedUnit;
import org.apache.uniffle.client.api.ShuffleReadClient;
import org.apache.uniffle.client.response.CompressedShuffleBlock;
-import org.apache.uniffle.common.RssShuffleUtils;
-import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.client.util.RssClientConfig;
+import org.apache.uniffle.common.compression.Codec;
+import org.apache.uniffle.common.config.RssConf;
public class RssShuffleDataIterator<K, C> extends AbstractIterator<Product2<K, C>> {
@@ -57,19 +59,29 @@ public class RssShuffleDataIterator<K, C> extends AbstractIterator<Product2<K, C
private ByteBufInputStream byteBufInputStream = null;
private long unCompressionLength = 0;
private ByteBuffer uncompressedData;
+ private Codec codec;
public RssShuffleDataIterator(
Serializer serializer,
ShuffleReadClient shuffleReadClient,
- ShuffleReadMetrics shuffleReadMetrics) {
+ ShuffleReadMetrics shuffleReadMetrics,
+ RssConf rssConf) {
this.serializerInstance = serializer.newInstance();
this.shuffleReadClient = shuffleReadClient;
this.shuffleReadMetrics = shuffleReadMetrics;
+ this.codec = Codec.newInstance(rssConf);
+ // todo: support off-heap bytebuffer
+ this.uncompressedData = ByteBuffer.allocate(
+ (int) rssConf.getSizeAsBytes(
+ RssClientConfig.RSS_WRITER_BUFFER_SIZE,
+ RssSparkConfig.RSS_WRITER_BUFFER_SIZE.defaultValueString()
+ )
+ );
}
- public Iterator<Tuple2<Object, Object>> createKVIterator(ByteBuffer data) {
+ public Iterator<Tuple2<Object, Object>> createKVIterator(ByteBuffer data, int size) {
clearDeserializationStream();
- byteBufInputStream = new ByteBufInputStream(Unpooled.wrappedBuffer(data), true);
+ byteBufInputStream = new ByteBufInputStream(Unpooled.wrappedBuffer(data.array(), 0, size), true);
deserializationStream = serializerInstance.deserializeStream(byteBufInputStream);
return deserializationStream.asKeyValueIterator();
}
@@ -109,24 +121,20 @@ public class RssShuffleDataIterator<K, C> extends AbstractIterator<Product2<K, C
shuffleReadMetrics.incFetchWaitTime(fetchDuration);
if (compressedData != null) {
shuffleReadMetrics.incRemoteBytesRead(compressedData.limit() - compressedData.position());
- // Directbytebuffers are not collected in time will cause executor easy
- // be killed by cluster managers(such as YARN) for using too much offheap memory
- if (uncompressedData != null && uncompressedData.isDirect()) {
- try {
- RssShuffleUtils.destroyDirectByteBuffer(uncompressedData);
- } catch (Exception e) {
- throw new RssException("Destroy DirectByteBuffer failed!", e);
- }
+
+ int uncompressedLen = compressedBlock.getUncompressLength();
+ if (uncompressedData == null || uncompressedData.capacity() < uncompressedLen) {
+ uncompressedData = ByteBuffer.allocate(uncompressedLen);
}
+ uncompressedData.clear();
long startDecompress = System.currentTimeMillis();
- uncompressedData = RssShuffleUtils.decompressData(
- compressedData, compressedBlock.getUncompressLength());
+ codec.decompress(compressedData, uncompressedLen, uncompressedData, 0);
unCompressionLength += compressedBlock.getUncompressLength();
long decompressDuration = System.currentTimeMillis() - startDecompress;
decompressTime += decompressDuration;
// create new iterator for shuffle data
long startSerialization = System.currentTimeMillis();
- recordsIterator = createKVIterator(uncompressedData);
+ recordsIterator = createKVIterator(uncompressedData, uncompressedLen);
long serializationDuration = System.currentTimeMillis() - startSerialization;
readTime += fetchDuration;
serializeTime += serializationDuration;
@@ -155,6 +163,7 @@ public class RssShuffleDataIterator<K, C> extends AbstractIterator<Product2<K, C
shuffleReadClient.close();
}
shuffleReadClient = null;
+ uncompressedData = null;
return BoxedUnit.UNIT;
}
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index ffb6000b..5c10ac67 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -37,9 +37,10 @@ import org.slf4j.LoggerFactory;
import scala.reflect.ClassTag$;
import org.apache.uniffle.client.util.ClientUtils;
-import org.apache.uniffle.common.RssShuffleUtils;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.compression.Codec;
+import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.ChecksumUtils;
@@ -77,6 +78,7 @@ public class WriteBufferManager extends MemoryConsumer {
private long uncompressedDataLen = 0;
private long requireMemoryInterval;
private int requireMemoryRetryMax;
+ private Codec codec;
public WriteBufferManager(
int shuffleId,
@@ -85,7 +87,8 @@ public class WriteBufferManager extends MemoryConsumer {
Serializer serializer,
Map<Integer, List<ShuffleServerInfo>> partitionToServers,
TaskMemoryManager taskMemoryManager,
- ShuffleWriteMetrics shuffleWriteMetrics) {
+ ShuffleWriteMetrics shuffleWriteMetrics,
+ RssConf rssConf) {
super(taskMemoryManager, taskMemoryManager.pageSizeBytes(), MemoryMode.ON_HEAP);
this.bufferSize = bufferManagerOptions.getBufferSize();
this.spillSize = bufferManagerOptions.getBufferSpillThreshold();
@@ -102,6 +105,7 @@ public class WriteBufferManager extends MemoryConsumer {
this.requireMemoryRetryMax = bufferManagerOptions.getRequireMemoryRetryMax();
this.arrayOutputStream = new WrappedByteArrayOutputStream(serializerBufferSize);
this.serializeStream = instance.serializeStream(arrayOutputStream);
+ this.codec = Codec.newInstance(rssConf);
}
public List<ShuffleBlockInfo> addRecord(int partitionId, Object key, Object value) {
@@ -170,7 +174,7 @@ public class WriteBufferManager extends MemoryConsumer {
byte[] data = wb.getData();
final int uncompressLength = data.length;
long start = System.currentTimeMillis();
- final byte[] compressed = RssShuffleUtils.compressData(data);
+ final byte[] compressed = codec.compress(data);
final long crc32 = ChecksumUtils.getCrc32(compressed);
compressTime += System.currentTimeMillis() - start;
final long blockId = ClientUtils.getBlockId(partitionId, taskAttemptId, getNextSeqNo(partitionId));
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/AbstractRssReaderTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/AbstractRssReaderTest.java
index 422c6d04..fd290cb4 100644
--- a/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/AbstractRssReaderTest.java
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/AbstractRssReaderTest.java
@@ -34,8 +34,9 @@ import scala.collection.Iterator;
import scala.reflect.ClassTag$;
import org.apache.uniffle.client.util.ClientUtils;
-import org.apache.uniffle.common.RssShuffleUtils;
import org.apache.uniffle.common.ShufflePartitionedBlock;
+import org.apache.uniffle.common.compression.Codec;
+import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.util.ChecksumUtils;
import org.apache.uniffle.storage.HdfsTestBase;
import org.apache.uniffle.storage.handler.api.ShuffleWriteHandler;
@@ -90,7 +91,7 @@ public abstract class AbstractRssReaderTest extends HdfsTestBase {
}
protected ShufflePartitionedBlock createShuffleBlock(byte[] data, long blockId) {
- byte[] compressData = RssShuffleUtils.compressData(data);
+ byte[] compressData = Codec.newInstance(new RssConf()).compress(data);
long crc = ChecksumUtils.getCrc32(compressData);
return new ShufflePartitionedBlock(compressData.length, data.length, crc, blockId, 0,
compressData);
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/RssShuffleDataIteratorTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/RssShuffleDataIteratorTest.java
index f4f55c18..78c3375b 100644
--- a/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/RssShuffleDataIteratorTest.java
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/RssShuffleDataIteratorTest.java
@@ -38,6 +38,7 @@ import org.apache.uniffle.client.api.ShuffleReadClient;
import org.apache.uniffle.client.impl.ShuffleReadClientImpl;
import org.apache.uniffle.client.util.ClientUtils;
import org.apache.uniffle.client.util.DefaultIdHelper;
+import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.util.ChecksumUtils;
import org.apache.uniffle.common.util.Constants;
import org.apache.uniffle.storage.handler.impl.HdfsShuffleWriteHandler;
@@ -96,7 +97,7 @@ public class RssShuffleDataIteratorTest extends AbstractRssReaderTest {
10, 10000, basePath, blockIdBitmap, taskIdBitmap, Lists.newArrayList(),
new Configuration(), new DefaultIdHelper());
return new RssShuffleDataIterator(KRYO_SERIALIZER, readClient,
- new ShuffleReadMetrics());
+ new ShuffleReadMetrics(), new RssConf());
}
@Test
@@ -119,7 +120,6 @@ public class RssShuffleDataIteratorTest extends AbstractRssReaderTest {
validateResult(rssShuffleDataIterator, expectedData, 20);
assertEquals(20, rssShuffleDataIterator.getShuffleReadMetrics().recordsRead());
- assertEquals(256, rssShuffleDataIterator.getShuffleReadMetrics().remoteBytesRead());
assertTrue(rssShuffleDataIterator.getShuffleReadMetrics().fetchWaitTime() > 0);
}
@@ -250,7 +250,7 @@ public class RssShuffleDataIteratorTest extends AbstractRssReaderTest {
ShuffleReadClient mockClient = mock(ShuffleReadClient.class);
doNothing().when(mockClient).close();
RssShuffleDataIterator dataIterator =
- new RssShuffleDataIterator(KRYO_SERIALIZER, mockClient, new ShuffleReadMetrics());
+ new RssShuffleDataIterator(KRYO_SERIALIZER, mockClient, new ShuffleReadMetrics(), new RssConf());
dataIterator.cleanup();
verify(mockClient, times(1)).close();
}
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
index 665f5d2d..3fc20398 100644
--- a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
@@ -29,6 +29,7 @@ import org.apache.spark.shuffle.RssSparkConfig;
import org.junit.jupiter.api.Test;
import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.config.RssConf;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -47,7 +48,7 @@ public class WriteBufferManagerTest {
BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
WriteBufferManager wbm = new WriteBufferManager(
0, 0, bufferOptions, kryoSerializer,
- Maps.newHashMap(), mockTaskMemoryManager, new ShuffleWriteMetrics());
+ Maps.newHashMap(), mockTaskMemoryManager, new ShuffleWriteMetrics(), new RssConf());
WriteBufferManager spyManager = spy(wbm);
doReturn(512L).when(spyManager).acquireMemory(anyLong());
return spyManager;
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 8f076040..26022a54 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -305,9 +305,15 @@ public class RssShuffleManager implements ShuffleManager {
BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
ShuffleWriteMetrics writeMetrics =
context.taskMetrics().shuffleWriteMetrics();
WriteBufferManager bufferManager = new WriteBufferManager(
- shuffleId, context.taskAttemptId(), bufferOptions, rssHandle.getDependency().serializer(),
- rssHandle.getPartitionToServers(), context.taskMemoryManager(),
- writeMetrics);
+ shuffleId,
+ context.taskAttemptId(),
+ bufferOptions,
+ rssHandle.getDependency().serializer(),
+ rssHandle.getPartitionToServers(),
+ context.taskMemoryManager(),
+ writeMetrics,
+ RssSparkConfig.toRssConf(sparkConf)
+ );
taskToBufferManager.put(taskId, bufferManager);
return new RssShuffleWriter(rssHandle.getAppId(), shuffleId, taskId, context.taskAttemptId(), bufferManager,
@@ -360,7 +366,7 @@ public class RssShuffleManager implements ShuffleManager {
rssShuffleHandle, shuffleRemoteStoragePath, indexReadLimit, readerHadoopConf,
storageType, (int) readBufferSize, partitionNumPerRange, partitionNum,
- blockIdBitmap, taskIdBitmap);
+ blockIdBitmap, taskIdBitmap, RssSparkConfig.toRssConf(sparkConf));
} else {
throw new RuntimeException("Unexpected ShuffleHandle:" + handle.getClass().getName());
}
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index ef97bea3..a32ba226 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -44,6 +44,7 @@ import org.apache.uniffle.client.api.ShuffleReadClient;
import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.client.request.CreateShuffleReadClientRequest;
import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssConf;
public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
@@ -67,6 +68,7 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
private Roaring64NavigableMap taskIdBitmap;
private List<ShuffleServerInfo> shuffleServerInfoList;
private Configuration hadoopConf;
+ private RssConf rssConf;
public RssShuffleReader(
int startPartition,
@@ -81,7 +83,8 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
int partitionNumPerRange,
int partitionNum,
Roaring64NavigableMap blockIdBitmap,
- Roaring64NavigableMap taskIdBitmap) {
+ Roaring64NavigableMap taskIdBitmap,
+ RssConf rssConf) {
this.appId = rssShuffleHandle.getAppId();
this.startPartition = startPartition;
this.endPartition = endPartition;
@@ -101,6 +104,7 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
this.hadoopConf = hadoopConf;
this.shuffleServerInfoList =
(List<ShuffleServerInfo>) (rssShuffleHandle.getPartitionToServers().get(startPartition));
+ this.rssConf = rssConf;
}
@Override
@@ -113,7 +117,7 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
ShuffleReadClient shuffleReadClient = ShuffleClientFactory.getInstance().createShuffleReadClient(request);
RssShuffleDataIterator rssShuffleDataIterator = new RssShuffleDataIterator<K, C>(
shuffleDependency.serializer(), shuffleReadClient,
- context.taskMetrics().shuffleReadMetrics());
+ context.taskMetrics().shuffleReadMetrics(), rssConf);
CompletionIterator completionIterator =
CompletionIterator$.MODULE$.apply(rssShuffleDataIterator, new AbstractFunction0<BoxedUnit>() {
@Override
diff --git a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
index 473ce609..ce33f47c 100644
--- a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
+++ b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
@@ -31,6 +31,7 @@ import org.junit.jupiter.api.Test;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
import scala.Option;
+import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.storage.handler.impl.HdfsShuffleWriteHandler;
import org.apache.uniffle.storage.util.StorageType;
@@ -73,7 +74,7 @@ public class RssShuffleReaderTest extends AbstractRssReaderTest {
RssShuffleReader rssShuffleReaderSpy = spy(new RssShuffleReader<String, String>(0, 1, contextMock,
handleMock, basePath, 1000, conf, StorageType.HDFS.name(),
- 1000, 2, 10, blockIdBitmap, taskIdBitmap));
+ 1000, 2, 10, blockIdBitmap, taskIdBitmap, new RssConf()));
validateResult(rssShuffleReaderSpy.read(), expectedData, 10);
}
diff --git a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index 084a731c..f71900ce 100644
--- a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++ b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -47,6 +47,7 @@ import scala.collection.mutable.MutableList;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.storage.util.StorageType;
import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -91,7 +92,7 @@ public class RssShuffleWriterTest {
BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
WriteBufferManager bufferManager = new WriteBufferManager(
0, 0, bufferOptions, kryoSerializer,
- Maps.newHashMap(), mockTaskMemoryManager, new ShuffleWriteMetrics());
+ Maps.newHashMap(), mockTaskMemoryManager, new ShuffleWriteMetrics(), new RssConf());
WriteBufferManager bufferManagerSpy = spy(bufferManager);
doReturn(1000000L).when(bufferManagerSpy).acquireMemory(anyLong());
@@ -197,7 +198,7 @@ public class RssShuffleWriterTest {
BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
WriteBufferManager bufferManager = new WriteBufferManager(
0, 0, bufferOptions, kryoSerializer,
- partitionToServers, mockTaskMemoryManager, shuffleWriteMetrics);
+ partitionToServers, mockTaskMemoryManager, shuffleWriteMetrics, new RssConf());
WriteBufferManager bufferManagerSpy = spy(bufferManager);
doReturn(1000000L).when(bufferManagerSpy).acquireMemory(anyLong());
@@ -219,12 +220,14 @@ public class RssShuffleWriterTest {
assertTrue(rssShuffleWriterSpy.getShuffleWriteMetrics().shuffleWriteTime() > 0);
assertEquals(6, rssShuffleWriterSpy.getShuffleWriteMetrics().shuffleRecordsWritten());
- assertEquals(144, rssShuffleWriterSpy.getShuffleWriteMetrics().shuffleBytesWritten());
+ assertEquals(
+ shuffleBlockInfos.stream().mapToInt(ShuffleBlockInfo::getLength).sum(),
+ rssShuffleWriterSpy.getShuffleWriteMetrics().shuffleBytesWritten()
+ );
assertEquals(6, shuffleBlockInfos.size());
for (ShuffleBlockInfo shuffleBlockInfo : shuffleBlockInfos) {
assertEquals(0, shuffleBlockInfo.getShuffleId());
- assertEquals(24, shuffleBlockInfo.getLength());
assertEquals(22, shuffleBlockInfo.getUncompressLength());
if (shuffleBlockInfo.getPartitionId() == 0) {
assertEquals(shuffleBlockInfo.getShuffleServerInfos(), ssi12);
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 41c2f4d7..ea29a4cd 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -331,7 +331,7 @@ public class RssShuffleManager implements ShuffleManager {
WriteBufferManager bufferManager = new WriteBufferManager(
shuffleId, context.taskAttemptId(), bufferOptions, rssHandle.getDependency().serializer(),
rssHandle.getPartitionToServers(), context.taskMemoryManager(),
- writeMetrics);
+ writeMetrics, RssSparkConfig.toRssConf(sparkConf));
taskToBufferManager.put(taskId, bufferManager);
LOG.info("RssHandle appId {} shuffleId {} ", rssHandle.getAppId(),
rssHandle.getShuffleId());
return new RssShuffleWriter(rssHandle.getAppId(), shuffleId, taskId,
context.taskAttemptId(), bufferManager,
@@ -459,7 +459,8 @@ public class RssShuffleManager implements ShuffleManager {
partitionNum,
RssUtils.generatePartitionToBitmap(blockIdBitmap, startPartition, endPartition),
taskIdBitmap,
- readMetrics);
+ readMetrics,
+ RssSparkConfig.toRssConf(sparkConf));
}
private Roaring64NavigableMap getExpectedTasksByExecutorId(
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index a565cfe4..2806ce82 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -50,6 +50,7 @@ import org.apache.uniffle.client.api.ShuffleReadClient;
import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.client.request.CreateShuffleReadClientRequest;
import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssConf;
public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
private static final Logger LOG = LoggerFactory.getLogger(RssShuffleReader.class);
@@ -74,6 +75,7 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
private int mapStartIndex;
private int mapEndIndex;
private ShuffleReadMetrics readMetrics;
+ private RssConf rssConf;
public RssShuffleReader(
int startPartition,
@@ -90,7 +92,8 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
int partitionNum,
Map<Integer, Roaring64NavigableMap> partitionToExpectBlocks,
Roaring64NavigableMap taskIdBitmap,
- ShuffleReadMetrics readMetrics) {
+ ShuffleReadMetrics readMetrics,
+ RssConf rssConf) {
this.appId = rssShuffleHandle.getAppId();
this.startPartition = startPartition;
this.endPartition = endPartition;
@@ -111,6 +114,7 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
this.hadoopConf = hadoopConf;
this.readMetrics = readMetrics;
this.partitionToShuffleServers = rssShuffleHandle.getPartitionToServers();
+ this.rssConf = rssConf;
}
@Override
@@ -201,7 +205,7 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
ShuffleReadClient shuffleReadClient = ShuffleClientFactory.getInstance().createShuffleReadClient(request);
RssShuffleDataIterator iterator = new RssShuffleDataIterator<K, C>(
shuffleDependency.serializer(), shuffleReadClient,
- readMetrics);
+ readMetrics, rssConf);
CompletionIterator<Product2<K, C>, RssShuffleDataIterator<K, C>> completionIterator =
CompletionIterator$.MODULE$.apply(iterator, () -> iterator.cleanup());
iterators.add(completionIterator);
diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
index 70938c88..5f8eceeb 100644
--- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
+++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
@@ -32,6 +32,7 @@ import org.junit.jupiter.api.Test;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
import scala.Option;
+import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.storage.handler.impl.HdfsShuffleWriteHandler;
import org.apache.uniffle.storage.util.StorageType;
@@ -93,7 +94,7 @@ public class RssShuffleReaderTest extends AbstractRssReaderTest {
1,
partitionToExpectBlocks,
taskIdBitmap,
- new ShuffleReadMetrics()));
+ new ShuffleReadMetrics(), new RssConf()));
validateResult(rssShuffleReaderSpy.read(), expectedData, 10);
writeTestData(writeHandler1, 2, 4, expectedData,
@@ -114,7 +115,8 @@ public class RssShuffleReaderTest extends AbstractRssReaderTest {
2,
partitionToExpectBlocks,
taskIdBitmap,
- new ShuffleReadMetrics()));
+ new ShuffleReadMetrics(), new RssConf())
+ );
validateResult(rssShuffleReaderSpy1.read(), expectedData, 18);
RssShuffleReader rssShuffleReaderSpy2 = spy(new RssShuffleReader<String, String>(
@@ -132,7 +134,7 @@ public class RssShuffleReaderTest extends AbstractRssReaderTest {
2,
partitionToExpectBlocks,
Roaring64NavigableMap.bitmapOf(),
- new ShuffleReadMetrics()));
+ new ShuffleReadMetrics(), new RssConf()));
validateResult(rssShuffleReaderSpy2.read(), Maps.newHashMap(), 0);
}
diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index 1b7afcd9..98ffc8a6 100644
--- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -49,6 +49,7 @@ import scala.collection.mutable.MutableList;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.storage.util.StorageType;
import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -98,7 +99,7 @@ public class RssShuffleWriterTest {
BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
WriteBufferManager bufferManager = new WriteBufferManager(
0, 0, bufferOptions, kryoSerializer,
- Maps.newHashMap(), mockTaskMemoryManager, new ShuffleWriteMetrics());
+ Maps.newHashMap(), mockTaskMemoryManager, new ShuffleWriteMetrics(), new RssConf());
WriteBufferManager bufferManagerSpy = spy(bufferManager);
RssShuffleWriter rssShuffleWriter = new RssShuffleWriter("appId", 0,
"taskId", 1L,
@@ -206,7 +207,7 @@ public class RssShuffleWriterTest {
ShuffleWriteMetrics shuffleWriteMetrics = new ShuffleWriteMetrics();
WriteBufferManager bufferManager = new WriteBufferManager(
0, 0, bufferOptions, kryoSerializer,
- partitionToServers, mockTaskMemoryManager, shuffleWriteMetrics);
+ partitionToServers, mockTaskMemoryManager, shuffleWriteMetrics, new RssConf());
WriteBufferManager bufferManagerSpy = spy(bufferManager);
RssShuffleWriter rssShuffleWriter = new RssShuffleWriter("appId", 0,
"taskId", 1L,
bufferManagerSpy, shuffleWriteMetrics, manager, conf,
mockShuffleWriteClient, mockHandle);
@@ -228,26 +229,14 @@ public class RssShuffleWriterTest {
assertTrue(shuffleWriteMetrics.writeTime() > 0);
assertEquals(6, shuffleWriteMetrics.recordsWritten());
- // Spark3 and Spark2 use different version lz4, their length is different
- // it can happen that 2 different platforms compress the same data differently,
- // yet the decoded outcome remains identical to original.
- // https://github.com/lz4/lz4/issues/812
- if (TestUtils.isMacOnAppleSilicon()) {
- assertEquals(144, shuffleWriteMetrics.bytesWritten());
- } else {
- assertEquals(120, shuffleWriteMetrics.bytesWritten());
- }
+
+ assertEquals(
+ shuffleBlockInfos.stream().mapToInt(ShuffleBlockInfo::getLength).sum(),
+ shuffleWriteMetrics.bytesWritten()
+ );
assertEquals(6, shuffleBlockInfos.size());
for (ShuffleBlockInfo shuffleBlockInfo : shuffleBlockInfos) {
- // it can happen that 2 different platforms compress the same data differently,
- // yet the decoded outcome remains identical to original.
- // https://github.com/lz4/lz4/issues/812
- if (TestUtils.isMacOnAppleSilicon()) {
- assertEquals(24, shuffleBlockInfo.getLength());
- } else {
- assertEquals(20, shuffleBlockInfo.getLength());
- }
assertEquals(22, shuffleBlockInfo.getUncompressLength());
assertEquals(0, shuffleBlockInfo.getShuffleId());
if (shuffleBlockInfo.getPartitionId() == 0) {
diff --git a/common/pom.xml b/common/pom.xml
index 20c5049b..f043eb9c 100644
--- a/common/pom.xml
+++ b/common/pom.xml
@@ -94,6 +94,11 @@
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-minicluster</artifactId>
</dependency>
+ <dependency>
+ <groupId>com.github.luben</groupId>
+ <artifactId>zstd-jni</artifactId>
+ <scope>provided</scope>
+ </dependency>
</dependencies>
<build>
diff --git a/common/src/main/java/org/apache/uniffle/common/RssShuffleUtils.java b/common/src/main/java/org/apache/uniffle/common/RssShuffleUtils.java
index 58db058e..42788aa8 100644
--- a/common/src/main/java/org/apache/uniffle/common/RssShuffleUtils.java
+++ b/common/src/main/java/org/apache/uniffle/common/RssShuffleUtils.java
@@ -22,44 +22,8 @@ import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import com.google.common.base.Preconditions;
-import net.jpountz.lz4.LZ4Compressor;
-import net.jpountz.lz4.LZ4Factory;
-import net.jpountz.lz4.LZ4FastDecompressor;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
public class RssShuffleUtils {
-
- private static final Logger LOG = LoggerFactory.getLogger(RssShuffleUtils.class);
-
- public static byte[] compressData(byte[] data) {
- LZ4Compressor compressor = LZ4Factory.fastestInstance().fastCompressor();
- return compressor.compress(data);
- }
-
- public static byte[] decompressData(byte[] data, int uncompressLength) {
- LZ4FastDecompressor fastDecompressor = LZ4Factory.fastestInstance().fastDecompressor();
- byte[] uncompressData = new byte[uncompressLength];
- fastDecompressor.decompress(data, 0, uncompressData, 0, uncompressLength);
- return uncompressData;
- }
-
- public static ByteBuffer decompressData(ByteBuffer data, int uncompressLength) {
- return decompressData(data, uncompressLength, true);
- }
-
- public static ByteBuffer decompressData(ByteBuffer data, int uncompressLength, boolean useDirectMem) {
- LZ4FastDecompressor fastDecompressor = LZ4Factory.fastestInstance().fastDecompressor();
- ByteBuffer uncompressData;
- if (useDirectMem) {
- uncompressData = ByteBuffer.allocateDirect(uncompressLength);
- } else {
- uncompressData = ByteBuffer.allocate(uncompressLength);
- }
- fastDecompressor.decompress(data, data.position(), uncompressData, 0, uncompressLength);
- return uncompressData;
- }
-
/**
* DirectByteBuffers are garbage collected by using a phantom reference and a
* reference queue. Every once a while, the JVM checks the reference queue and
diff --git a/common/src/main/java/org/apache/uniffle/common/compression/Codec.java b/common/src/main/java/org/apache/uniffle/common/compression/Codec.java
new file mode 100644
index 00000000..9ff7d85d
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/compression/Codec.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.compression;
+
+import java.nio.ByteBuffer;
+
+import org.apache.uniffle.common.config.RssConf;
+
+import static org.apache.uniffle.common.config.RssClientConf.COMPRESSION_TYPE;
+import static org.apache.uniffle.common.config.RssClientConf.ZSTD_COMPRESSION_LEVEL;
+
+public abstract class Codec {
+
+ public static Codec newInstance(RssConf rssConf) {
+ Type type = rssConf.get(COMPRESSION_TYPE);
+ switch (type) {
+ case ZSTD:
+ return new ZstdCodec(rssConf.get(ZSTD_COMPRESSION_LEVEL));
+ case NOOP:
+ return new NoOpCodec();
+ case LZ4:
+ default:
+ return new Lz4Codec();
+ }
+ }
+
+ public abstract void decompress(ByteBuffer src, int uncompressedLen, ByteBuffer dest, int destOffset);
+
+ public abstract byte[] compress(byte[] src);
+
+ public enum Type {
+ LZ4,
+ ZSTD,
+ NOOP,
+ }
+}
diff --git a/common/src/main/java/org/apache/uniffle/common/compression/Lz4Codec.java b/common/src/main/java/org/apache/uniffle/common/compression/Lz4Codec.java
new file mode 100644
index 00000000..59b6df6f
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/compression/Lz4Codec.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.compression;
+
+import java.nio.ByteBuffer;
+
+import net.jpountz.lz4.LZ4Factory;
+
+public class Lz4Codec extends Codec {
+
+ private LZ4Factory lz4Factory;
+
+ public Lz4Codec() {
+ this.lz4Factory = LZ4Factory.fastestInstance();
+ }
+
+ @Override
+ public void decompress(ByteBuffer src, int uncompressedLen, ByteBuffer dest, int destOffset) {
+ lz4Factory.fastDecompressor().decompress(src, src.position(), dest, destOffset, uncompressedLen);
+ }
+
+ @Override
+ public byte[] compress(byte[] src) {
+ return lz4Factory.fastCompressor().compress(src);
+ }
+}
diff --git a/common/src/main/java/org/apache/uniffle/common/compression/NoOpCodec.java b/common/src/main/java/org/apache/uniffle/common/compression/NoOpCodec.java
new file mode 100644
index 00000000..99c7cb4e
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/compression/NoOpCodec.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.compression;
+
+import java.nio.ByteBuffer;
+
+public class NoOpCodec extends Codec {
+
+ @Override
+ public void decompress(ByteBuffer src, int uncompressedLen, ByteBuffer dest, int destOffset) {
+ dest.put(src);
+ }
+
+ @Override
+ public byte[] compress(byte[] src) {
+ byte[] dst = new byte[src.length];
+ System.arraycopy(src, 0, dst, 0, src.length);
+ return dst;
+ }
+}
diff --git a/common/src/main/java/org/apache/uniffle/common/compression/ZstdCodec.java b/common/src/main/java/org/apache/uniffle/common/compression/ZstdCodec.java
new file mode 100644
index 00000000..0c596af8
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/compression/ZstdCodec.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.compression;
+
+import java.nio.ByteBuffer;
+
+import com.github.luben.zstd.Zstd;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.exception.RssException;
+
+public class ZstdCodec extends Codec {
+ private static final Logger LOGGER = LoggerFactory.getLogger(ZstdCodec.class);
+
+ private final int compressionLevel;
+
+ public ZstdCodec(int level) {
+ this.compressionLevel = level;
+ LOGGER.info("Initializing zstd compressor.");
+ }
+
+ @Override
+ public void decompress(ByteBuffer src, int uncompressedLen, ByteBuffer dst, int dstOffset) {
+ if (src.isDirect() && dst.isDirect()) {
+ long size = Zstd.decompressDirectByteBuffer(
+ dst, dstOffset, uncompressedLen,
+ src, src.position(), src.limit() - src.position()
+ );
+ if (size != uncompressedLen) {
+ throw new RssException(
+ "This should not happen that the decompressed data size is not
equals to original size.");
+ }
+ return;
+ }
+
+ if (!src.isDirect() && !dst.isDirect()) {
+ Zstd.decompressByteArray(
+ dst.array(), dstOffset, uncompressedLen,
+ src.array(), src.position(), src.limit() - src.position()
+ );
+ return;
+ }
+
+ throw new IllegalStateException("Zstd only supports the same type of
bytebuffer decompression.");
+ }
+
+ @Override
+ public byte[] compress(byte[] src) {
+ return Zstd.compress(src, compressionLevel);
+ }
+}
diff --git a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
new file mode 100644
index 00000000..99d82e03
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.config;
+
+import org.apache.uniffle.common.compression.Codec;
+
+import static org.apache.uniffle.common.compression.Codec.Type.LZ4;
+
+public class RssClientConf {
+
+ public static final ConfigOption<Codec.Type> COMPRESSION_TYPE = ConfigOptions
+ .key("rss.client.io.compression.codec")
+ .enumType(Codec.Type.class)
+ .defaultValue(LZ4)
+ .withDescription("The compression codec used to compress the shuffle data. "
+ + "The default codec is `LZ4`; `ZSTD` can also be used.");
+
+ public static final ConfigOption<Integer> ZSTD_COMPRESSION_LEVEL = ConfigOptions
+ .key("rss.client.io.compression.zstd.level")
+ .intType()
+ .defaultValue(3)
+ .withDescription("The zstd compression level. The default level is 3.");
+}
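
A sketch of how a client selects the codec through these options, following
the pattern in CompressionTest below; whether Codec.newInstance also honors
ZSTD_COMPRESSION_LEVEL is an assumption here, not shown in this hunk:

    RssConf conf = new RssConf();
    conf.set(RssClientConf.COMPRESSION_TYPE, Codec.Type.ZSTD);
    conf.set(RssClientConf.ZSTD_COMPRESSION_LEVEL, 5);  // override the default of 3
    Codec codec = Codec.newInstance(conf);              // factory from this PR's Codec.java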
diff --git a/common/src/test/java/org/apache/uniffle/common/RssShuffleUtilsTest.java b/common/src/test/java/org/apache/uniffle/common/RssShuffleUtilsTest.java
deleted file mode 100644
index f6f00a17..00000000
--- a/common/src/test/java/org/apache/uniffle/common/RssShuffleUtilsTest.java
+++ /dev/null
@@ -1,99 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.uniffle.common;
-
-import java.lang.reflect.Field;
-import java.nio.Buffer;
-import java.nio.ByteBuffer;
-
-import org.apache.commons.lang3.RandomUtils;
-import org.junit.jupiter.api.Test;
-import org.junit.jupiter.params.ParameterizedTest;
-import org.junit.jupiter.params.provider.ValueSource;
-import sun.misc.Unsafe;
-
-import static org.junit.jupiter.api.Assertions.assertArrayEquals;
-import static org.junit.jupiter.api.Assertions.assertFalse;
-
-public class RssShuffleUtilsTest {
-
- @ParameterizedTest
- @ValueSource(ints = {1, 1024, 128 * 1024, 512 * 1024, 1024 * 1024, 4 * 1024 * 1024})
- public void testCompression(int size) {
- byte[] data = RandomUtils.nextBytes(size);
- byte[] compressed = RssShuffleUtils.compressData(data);
- byte[] decompressed = RssShuffleUtils.decompressData(compressed, size);
- assertArrayEquals(data, decompressed);
-
- ByteBuffer decompressedBB = RssShuffleUtils.decompressData(ByteBuffer.wrap(compressed), size);
- byte[] buffer = new byte[size];
- decompressedBB.get(buffer);
- assertArrayEquals(data, buffer);
-
- ByteBuffer decompressedBB2 = RssShuffleUtils.decompressData(ByteBuffer.wrap(compressed), size, false);
- byte[] buffer2 = new byte[size];
- decompressedBB2.get(buffer2);
- assertArrayEquals(data, buffer2);
- }
-
- @Test
- public void testDestroyDirectByteBuffer() throws Exception {
- int size = 10;
- byte b = 1;
- ByteBuffer byteBuffer = ByteBuffer.allocateDirect(size);
- for (int i = 0; i < size; i++) {
- byteBuffer.put(b);
- }
- byteBuffer.flip();
-
- // Get valid native pointer through `address` in `DirectByteBuffer`
- Unsafe unsafe = getUnsafe();
- long addressInByteBuffer = address(byteBuffer);
- long originalAddress = unsafe.getAddress(addressInByteBuffer);
-
- RssShuffleUtils.destroyDirectByteBuffer(byteBuffer);
-
- // The memory may not be released fast enough.
- // If native pointer changes, `address` in `DirectByteBuffer` is invalid
- while (unsafe.getAddress(addressInByteBuffer) == originalAddress) {
- Thread.sleep(200);
- }
- boolean same = true;
- byte[] read = new byte[size];
- byteBuffer.get(read);
- for (byte br : read) {
- if (b != br) {
- same = false;
- break;
- }
- }
- assertFalse(same);
- }
-
- private Unsafe getUnsafe() throws NoSuchFieldException, IllegalAccessException {
- Field unsafeField = Unsafe.class.getDeclaredField("theUnsafe");
- unsafeField.setAccessible(true);
- return (Unsafe) unsafeField.get(null);
- }
-
- private long address(ByteBuffer buffer) throws NoSuchFieldException, IllegalAccessException {
- Field addressField = Buffer.class.getDeclaredField("address");
- addressField.setAccessible(true);
- return (long) addressField.get(buffer);
- }
-}
diff --git a/common/src/test/java/org/apache/uniffle/common/compression/CompressionTest.java b/common/src/test/java/org/apache/uniffle/common/compression/CompressionTest.java
new file mode 100644
index 00000000..cb2fdc6f
--- /dev/null
+++ b/common/src/test/java/org/apache/uniffle/common/compression/CompressionTest.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.compression;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.commons.lang3.RandomUtils;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+
+import org.apache.uniffle.common.config.RssConf;
+
+import static org.apache.uniffle.common.config.RssClientConf.COMPRESSION_TYPE;
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+
+public class CompressionTest {
+
+ static List<Arguments> testCompression() {
+ int[] sizes = {1, 1024, 128 * 1024, 512 * 1024, 1024 * 1024, 4 * 1024 * 1024};
+ Codec.Type[] types = {Codec.Type.ZSTD, Codec.Type.LZ4};
+
+ List<Arguments> arguments = new ArrayList<>();
+ for (int size : sizes) {
+ for (Codec.Type type : types) {
+ arguments.add(
+ Arguments.of(size, type)
+ );
+ }
+ }
+ return arguments;
+ }
+
+ @ParameterizedTest
+ @MethodSource
+ public void testCompression(int size, Codec.Type type) {
+ byte[] data = RandomUtils.nextBytes(size);
+ RssConf conf = new RssConf();
+ conf.set(COMPRESSION_TYPE, type);
+
+ // case1: heap bytebuffer
+ Codec codec = Codec.newInstance(conf);
+ byte[] compressed = codec.compress(data);
+
+ ByteBuffer dest = ByteBuffer.allocate(size);
+ codec.decompress(ByteBuffer.wrap(compressed), size, dest, 0);
+
+ assertArrayEquals(data, dest.array());
+
+ // case2: non-heap bytebuffer
+ ByteBuffer src = ByteBuffer.allocateDirect(compressed.length);
+ src.put(compressed);
+ src.flip();
+ ByteBuffer dst = ByteBuffer.allocateDirect(size);
+ codec.decompress(src, size, dst, 0);
+ byte[] res = new byte[size];
+ dst.get(res);
+ assertArrayEquals(data, res);
+
+ // case3: use the recycled bytebuffer
+ ByteBuffer recycledDst = ByteBuffer.allocate(size + 10);
+ codec.decompress(ByteBuffer.wrap(compressed), size, recycledDst, 0);
+ recycledDst.get(res);
+ assertArrayEquals(data, res);
+ }
+}
diff --git a/docs/client_guide.md b/docs/client_guide.md
index 9b5a208a..c945802d 100644
--- a/docs/client_guide.md
+++ b/docs/client_guide.md
@@ -89,6 +89,8 @@ These configurations are shared by all types of clients.
|<client_type>.rss.client.assignment.tags|-|The comma-separated list of tags for deciding assignment shuffle servers. Notice that the SHUFFLE_SERVER_VERSION will always be used as an assignment tag whether this conf is set or not|
|<client_type>.rss.client.data.commit.pool.size|The number of assigned shuffle servers|The thread size for sending commit to shuffle servers|
|<client_type>.rss.client.assignment.shuffle.nodes.max|-1|The number of required assignment shuffle servers. If it is less than or equal to 0, or greater than the coordinator's config of "rss.coordinator.shuffle.nodes.max", the value of "rss.coordinator.shuffle.nodes.max" is used by default|
+|<client_type>.rss.client.io.compression.codec|lz4|The compression codec used to compress the shuffle data. The default codec is `lz4`; `zstd` can also be used.|
+|<client_type>.rss.client.io.compression.zstd.level|3|The zstd compression level. The default level is 3.|
Notice:
1. `<client_type>` should be `spark` or `mapreduce`
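
For example, a Spark client enabling zstd would set (an illustration of the
`<client_type>` rule above, not an exhaustive configuration):

    spark.rss.client.io.compression.codec    zstd
    spark.rss.client.io.compression.zstd.level    3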
diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RepartitionWithLocalFileRssTest.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RepartitionWithLocalFileRssTest.java
index 82586d91..f97f043e 100644
--- a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RepartitionWithLocalFileRssTest.java
+++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RepartitionWithLocalFileRssTest.java
@@ -18,18 +18,25 @@
package org.apache.uniffle.test;
import java.io.File;
+import java.util.ArrayList;
+import java.util.List;
import java.util.Map;
+import java.util.concurrent.TimeUnit;
import com.google.common.collect.Maps;
import com.google.common.io.Files;
+import com.google.common.util.concurrent.Uninterruptibles;
import org.apache.spark.SparkConf;
import org.apache.spark.shuffle.RssSparkConfig;
import org.junit.jupiter.api.BeforeAll;
+import org.apache.uniffle.common.compression.Codec;
import org.apache.uniffle.coordinator.CoordinatorConf;
import org.apache.uniffle.server.ShuffleServerConf;
import org.apache.uniffle.storage.util.StorageType;
+import static org.apache.uniffle.common.config.RssClientConf.COMPRESSION_TYPE;
+
public class RepartitionWithLocalFileRssTest extends RepartitionTest {
@BeforeAll
@@ -53,4 +60,35 @@ public class RepartitionWithLocalFileRssTest extends RepartitionTest {
@Override
public void updateRssStorage(SparkConf sparkConf) {
}
+
+ /**
+ * Test different compression types with localfile rss mode.
+ * @throws Exception
+ */
+ @Override
+ public void run() throws Exception {
+ String fileName = generateTestFile();
+ SparkConf sparkConf = createSparkConf();
+ Uninterruptibles.sleepUninterruptibly(2, TimeUnit.SECONDS);
+
+ List<Map> results = new ArrayList<>();
+ Map resultWithoutRss = runSparkApp(sparkConf, fileName);
+ results.add(resultWithoutRss);
+
+ updateSparkConfWithRss(sparkConf);
+ updateSparkConfCustomer(sparkConf);
+ for (Codec.Type type :
+ new Codec.Type[]{
+ Codec.Type.NOOP,
+ Codec.Type.ZSTD,
+ Codec.Type.LZ4}) {
+ sparkConf.set("spark." + COMPRESSION_TYPE.key().toLowerCase(), type.name());
+ Map resultWithRss = runSparkApp(sparkConf, fileName);
+ results.add(resultWithRss);
+ }
+
+ for (int i = 1; i < results.size(); i++) {
+ verifyTestResult(results.get(0), results.get(i));
+ }
+ }
}
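
The loop above derives the Spark property name from the generic option key;
expanded, each iteration is equivalent to:

    // COMPRESSION_TYPE.key() is "rss.client.io.compression.codec" (already
    // lowercase, so toLowerCase() only guards the casing)
    sparkConf.set("spark.rss.client.io.compression.codec", type.name());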
diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkIntegrationTestBase.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkIntegrationTestBase.java
index 1ea90007..0cc5a17d 100644
--- a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkIntegrationTestBase.java
+++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkIntegrationTestBase.java
@@ -99,7 +99,7 @@ public abstract class SparkIntegrationTestBase extends IntegrationTestBase {
sparkConf.set(RssSparkConfig.RSS_HEARTBEAT_INTERVAL.key(), "2000");
}
- private void verifyTestResult(Map expected, Map actual) {
+ protected void verifyTestResult(Map expected, Map actual) {
assertEquals(expected.size(), actual.size());
for (Object expectedKey : expected.keySet()) {
assertEquals(expected.get(expectedKey), actual.get(expectedKey));
diff --git a/pom.xml b/pom.xml
index c18d1be7..61e999c2 100644
--- a/pom.xml
+++ b/pom.xml
@@ -75,6 +75,7 @@
<spotbugs.version>4.7.0</spotbugs.version>
<spotbugs-maven-plugin.version>4.7.0.0</spotbugs-maven-plugin.version>
<system-rules.version>1.19.0</system-rules.version>
+ <zstd-jni.version>1.5.2-3</zstd-jni.version>
<test.redirectToFile>true</test.redirectToFile>
<trimStackTrace>false</trimStackTrace>
</properties>
@@ -600,6 +601,12 @@
<artifactId>mockito-core</artifactId>
<version>${mockito.version}</version>
</dependency>
+
+ <dependency>
+ <groupId>com.github.luben</groupId>
+ <artifactId>zstd-jni</artifactId>
+ <version>${zstd-jni.version}</version>
+ </dependency>
</dependencies>
</dependencyManagement>