This is an automated email from the ASF dual-hosted git repository.
roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 6bd8d133 [#133] feat(netty): Introduce ShuffleServerGrpcNettyClient. (#839)
6bd8d133 is described below
commit 6bd8d1335b0de389defe836d550812d1d31bcd1f
Author: Xianming Lei <[email protected]>
AuthorDate: Mon May 8 14:56:53 2023 +0800
[#133] feat(netty): Introduce ShuffleServerGrpcNettyClient. (#839)
### What changes were proposed in this pull request?
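Introduce ShuffleServerGrpcNettyClient, and pass the client RssConf through the shuffle read/write client factories instead of separate storage-type, index-read-limit and read-buffer-size parameters.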
### Why are the changes needed?
This is part of #133.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Covered by the existing unit tests.
Co-authored-by: leixianming <[email protected]>
---
.../org/apache/hadoop/mapreduce/RssMRUtils.java | 2 +-
.../hadoop/mapreduce/task/reduce/RssShuffle.java | 16 +-
.../apache/spark/shuffle/RssShuffleManager.java | 17 +-
.../spark/shuffle/reader/RssShuffleReader.java | 14 +-
.../spark/shuffle/reader/RssShuffleReaderTest.java | 8 +-
.../apache/spark/shuffle/RssShuffleManager.java | 18 +-
.../spark/shuffle/reader/RssShuffleReader.java | 16 +-
.../spark/shuffle/reader/RssShuffleReaderTest.java | 20 +-
.../client/factory/ShuffleClientFactory.java | 29 ++-
.../uniffle/client/impl/ShuffleReadClientImpl.java | 69 +++++-
.../client/impl/ShuffleWriteClientImpl.java | 30 ++-
.../request/CreateShuffleReadClientRequest.java | 69 +++---
.../apache/uniffle/common/ShuffleIndexResult.java | 12 +-
.../uniffle/common/config/RssClientConf.java | 17 ++
.../common/netty/client/TransportClient.java | 26 ++-
.../common/segment/FixedSizeSegmentSplitter.java | 19 +-
.../common/segment/LocalOrderSegmentSplitter.java | 17 +-
.../segment/FixedSizeSegmentSplitterTest.java | 7 +-
.../segment/LocalOrderSegmentSplitterTest.java | 28 +--
.../client/factory/ShuffleServerClientFactory.java | 17 +-
.../client/impl/grpc/ShuffleServerGrpcClient.java | 15 +-
.../impl/grpc/ShuffleServerGrpcNettyClient.java | 253 +++++++++++++++++++++
.../RssGetInMemoryShuffleDataResponse.java | 7 +-
.../client/response/RssGetShuffleDataResponse.java | 8 +-
.../response/RssGetShuffleIndexResponse.java | 4 +-
.../uniffle/server/ShuffleServerGrpcService.java | 9 +-
.../storage/factory/ShuffleHandlerFactory.java | 5 +-
.../handler/impl/HdfsShuffleReadHandler.java | 2 +-
.../handler/impl/LocalFileServerReadHandler.java | 3 +-
.../request/CreateShuffleReadHandlerRequest.java | 10 +
.../handler/impl/HdfsClientReadHandlerTest.java | 2 +-
.../impl/LocalFileServerReadHandlerTest.java | 7 +-
32 files changed, 573 insertions(+), 203 deletions(-)
diff --git
a/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java
b/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java
index 7a732353..419a2e49 100644
--- a/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java
+++ b/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java
@@ -98,7 +98,7 @@ public class RssMRUtils {
.getInstance()
.createShuffleWriteClient(clientType, retryMax, retryIntervalMax,
heartBeatThreadNum, replica, replicaWrite, replicaRead,
replicaSkipEnabled,
- dataTransferPoolSize, dataCommitPoolSize);
+ dataTransferPoolSize, dataCommitPoolSize,
RssMRConfig.toRssConf(jobConf));
return client;
}
diff --git
a/client-mr/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssShuffle.java
b/client-mr/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssShuffle.java
index d08a5b2b..0094ecd5 100644
---
a/client-mr/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssShuffle.java
+++
b/client-mr/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssShuffle.java
@@ -44,7 +44,6 @@ import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.client.request.CreateShuffleReadClientRequest;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
-import org.apache.uniffle.common.util.UnitConverter;
public class RssShuffle<K, V> implements ShuffleConsumerPlugin<K, V>,
ExceptionReporter {
@@ -69,7 +68,6 @@ public class RssShuffle<K, V> implements
ShuffleConsumerPlugin<K, V>, ExceptionR
private Task reduceTask; //Used for status updates
private String appId;
- private String storageType;
private String clientType;
private int replicaWrite;
private int replicaRead;
@@ -78,8 +76,6 @@ public class RssShuffle<K, V> implements
ShuffleConsumerPlugin<K, V>, ExceptionR
private int partitionNum;
private int partitionNumPerRange;
private String basePath;
- private int indexReadLimit;
- private int readBufferSize;
private RemoteStorageInfo remoteStorageInfo;
private int appAttemptId;
@@ -102,7 +98,6 @@ public class RssShuffle<K, V> implements
ShuffleConsumerPlugin<K, V>, ExceptionR
// rss init
this.appId = RssMRUtils.getApplicationAttemptId().toString();
this.appAttemptId = RssMRUtils.getApplicationAttemptId().getAttemptId();
- this.storageType = RssMRUtils.getString(rssJobConf, mrJobConf,
RssMRConfig.RSS_STORAGE_TYPE);
this.replicaWrite = RssMRUtils.getInt(rssJobConf, mrJobConf,
RssMRConfig.RSS_DATA_REPLICA_WRITE,
RssMRConfig.RSS_DATA_REPLICA_WRITE_DEFAULT_VALUE);
this.replicaRead = RssMRUtils.getInt(rssJobConf, mrJobConf,
RssMRConfig.RSS_DATA_REPLICA_READ,
@@ -114,11 +109,6 @@ public class RssShuffle<K, V> implements
ShuffleConsumerPlugin<K, V>, ExceptionR
this.partitionNumPerRange = RssMRUtils.getInt(rssJobConf, mrJobConf,
RssMRConfig.RSS_PARTITION_NUM_PER_RANGE,
RssMRConfig.RSS_PARTITION_NUM_PER_RANGE_DEFAULT_VALUE);
this.basePath = RssMRUtils.getString(rssJobConf, mrJobConf,
RssMRConfig.RSS_REMOTE_STORAGE_PATH);
- this.indexReadLimit = RssMRUtils.getInt(rssJobConf, mrJobConf,
RssMRConfig.RSS_INDEX_READ_LIMIT,
- RssMRConfig.RSS_INDEX_READ_LIMIT_DEFAULT_VALUE);
- this.readBufferSize = (int)UnitConverter.byteStringAsBytes(
- RssMRUtils.getString(rssJobConf, mrJobConf,
RssMRConfig.RSS_CLIENT_READ_BUFFER_SIZE,
- RssMRConfig.RSS_CLIENT_READ_BUFFER_SIZE_DEFAULT_VALUE));
String remoteStorageConf = RssMRUtils.getString(rssJobConf, mrJobConf,
RssMRConfig.RSS_REMOTE_STORAGE_CONF, "");
this.remoteStorageInfo = new RemoteStorageInfo(basePath,
remoteStorageConf);
this.merger = createMergeManager(context);
@@ -193,9 +183,9 @@ public class RssShuffle<K, V> implements
ShuffleConsumerPlugin<K, V>, ExceptionR
JobConf readerJobConf = getRemoteConf();
boolean expectedTaskIdsBitmapFilterEnable = serverInfoList.size() > 1;
CreateShuffleReadClientRequest request = new
CreateShuffleReadClientRequest(
- appId, 0, reduceId.getTaskID().getId(), storageType, basePath,
indexReadLimit, readBufferSize,
- partitionNumPerRange, partitionNum, blockIdBitmap, taskIdBitmap,
serverInfoList,
- readerJobConf, new MRIdHelper(), expectedTaskIdsBitmapFilterEnable,
false);
+ appId, 0, reduceId.getTaskID().getId(), basePath,
partitionNumPerRange,
+ partitionNum, blockIdBitmap, taskIdBitmap, serverInfoList,
readerJobConf,
+ new MRIdHelper(), expectedTaskIdsBitmapFilterEnable,
RssMRConfig.toRssConf(rssJobConf));
ShuffleReadClient shuffleReadClient =
ShuffleClientFactory.getInstance().createShuffleReadClient(request);
RssFetcher fetcher = new RssFetcher(mrJobConf, reduceId, taskStatus,
merger, copyPhase, reporter, metrics,
shuffleReadClient, blockIdBitmap.getLongCardinality(),
RssMRConfig.toRssConf(rssJobConf));
diff --git
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 5a465fa5..f4740cf0 100644
---
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -137,11 +137,12 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
this.dataCommitPoolSize =
sparkConf.get(RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE);
int unregisterThreadPoolSize =
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE);
int unregisterRequestTimeoutSec =
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_REQUEST_TIMEOUT_SEC);
+ RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
this.shuffleWriteClient = ShuffleClientFactory
.getInstance()
.createShuffleWriteClient(clientType, retryMax, retryIntervalMax,
heartBeatThreadNum,
dataReplica, dataReplicaWrite, dataReplicaRead,
dataReplicaSkipEnabled, dataTransferPoolSize,
- dataCommitPoolSize, unregisterThreadPoolSize,
unregisterRequestTimeoutSec);
+ dataCommitPoolSize, unregisterThreadPoolSize,
unregisterRequestTimeoutSec, rssConf);
registerCoordinator();
// fetch client conf and apply them if necessary and disable ESS
if (isDriver && dynamicConfEnabled) {
@@ -161,7 +162,6 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
if (isDriver) {
heartBeatScheduledExecutorService =
ThreadUtils.getDaemonSingleThreadScheduledExecutor("rss-heartbeat");
- RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
if (rssConf.getBoolean(RssClientConfig.RSS_RESUBMIT_STAGE, false)
&& RssSparkShuffleUtils.isStageResubmitSupported()) {
LOG.info("stage resubmit is supported and enabled");
@@ -384,17 +384,9 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
public <K, C> ShuffleReader<K, C> getReader(ShuffleHandle handle,
int startPartition, int endPartition, TaskContext context) {
if (handle instanceof RssShuffleHandle) {
- final String storageType =
sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key());
- final int indexReadLimit =
sparkConf.get(RssSparkConfig.RSS_INDEX_READ_LIMIT);
RssShuffleHandle<K, C, ?> rssShuffleHandle = (RssShuffleHandle<K, C, ?>)
handle;
final int partitionNumPerRange =
sparkConf.get(RssSparkConfig.RSS_PARTITION_NUM_PER_RANGE);
final int partitionNum =
rssShuffleHandle.getDependency().partitioner().numPartitions();
- long readBufferSize =
sparkConf.getSizeAsBytes(RssSparkConfig.RSS_CLIENT_READ_BUFFER_SIZE.key(),
- RssSparkConfig.RSS_CLIENT_READ_BUFFER_SIZE.defaultValue().get());
- if (readBufferSize > Integer.MAX_VALUE) {
- LOG.warn(RssSparkConfig.RSS_CLIENT_READ_BUFFER_SIZE + " can support 2g
as max");
- readBufferSize = Integer.MAX_VALUE;
- }
int shuffleId = rssShuffleHandle.getShuffleId();
long start = System.currentTimeMillis();
Roaring64NavigableMap taskIdBitmap = getExpectedTasks(shuffleId,
startPartition, endPartition);
@@ -418,9 +410,8 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
return new RssShuffleReader<K, C>(
startPartition, endPartition, context,
- rssShuffleHandle, shuffleRemoteStoragePath, indexReadLimit,
- readerHadoopConf,
- storageType, (int) readBufferSize, partitionNumPerRange,
partitionNum,
+ rssShuffleHandle, shuffleRemoteStoragePath,
+ readerHadoopConf, partitionNumPerRange, partitionNum,
blockIdBitmap, taskIdBitmap, RssSparkConfig.toRssConf(sparkConf));
} else {
throw new RssException("Unexpected ShuffleHandle:" +
handle.getClass().getName());
diff --git
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index ec976113..82285f3b 100644
---
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -65,11 +65,8 @@ public class RssShuffleReader<K, C> implements
ShuffleReader<K, C> {
private Serializer serializer;
private String taskId;
private String basePath;
- private int indexReadLimit;
- private int readBufferSize;
private int partitionNumPerRange;
private int partitionNum;
- private String storageType;
private Roaring64NavigableMap blockIdBitmap;
private Roaring64NavigableMap taskIdBitmap;
private List<ShuffleServerInfo> shuffleServerInfoList;
@@ -82,10 +79,7 @@ public class RssShuffleReader<K, C> implements
ShuffleReader<K, C> {
TaskContext context,
RssShuffleHandle<K, C, ?> rssShuffleHandle,
String basePath,
- int indexReadLimit,
Configuration hadoopConf,
- String storageType,
- int readBufferSize,
int partitionNumPerRange,
int partitionNum,
Roaring64NavigableMap blockIdBitmap,
@@ -100,9 +94,6 @@ public class RssShuffleReader<K, C> implements
ShuffleReader<K, C> {
this.serializer = rssShuffleHandle.getDependency().serializer();
this.taskId = "" + context.taskAttemptId() + "_" + context.attemptNumber();
this.basePath = basePath;
- this.indexReadLimit = indexReadLimit;
- this.storageType = storageType;
- this.readBufferSize = readBufferSize;
this.partitionNumPerRange = partitionNumPerRange;
this.partitionNum = partitionNum;
this.blockIdBitmap = blockIdBitmap;
@@ -119,10 +110,9 @@ public class RssShuffleReader<K, C> implements
ShuffleReader<K, C> {
LOG.info("Shuffle read started:" + getReadInfo());
CreateShuffleReadClientRequest request = new
CreateShuffleReadClientRequest(
- appId, shuffleId, startPartition, storageType, basePath,
indexReadLimit, readBufferSize,
+ appId, shuffleId, startPartition, basePath,
partitionNumPerRange, partitionNum, blockIdBitmap, taskIdBitmap,
- shuffleServerInfoList, hadoopConf, expectedTaskIdsBitmapFilterEnable,
- rssConf.getBoolean(RssClientConf.OFF_HEAP_MEMORY_ENABLE));
+ shuffleServerInfoList, hadoopConf, expectedTaskIdsBitmapFilterEnable,
rssConf);
ShuffleReadClient shuffleReadClient =
ShuffleClientFactory.getInstance().createShuffleReadClient(request);
RssShuffleDataIterator rssShuffleDataIterator = new
RssShuffleDataIterator<K, C>(
shuffleDependency.serializer(), shuffleReadClient,
diff --git
a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
index 4d23ea5d..61a7fa07 100644
---
a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
+++
b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
@@ -36,6 +36,7 @@ import org.junit.jupiter.api.Test;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.storage.handler.impl.HdfsShuffleWriteHandler;
import org.apache.uniffle.storage.util.StorageType;
@@ -81,9 +82,12 @@ public class RssShuffleReaderTest extends
AbstractRssReaderTest {
when(dependencyMock.aggregator()).thenReturn(Option.empty());
when(dependencyMock.keyOrdering()).thenReturn(Option.empty());
+ RssConf rssConf = new RssConf();
+ rssConf.set(RssClientConf.RSS_STORAGE_TYPE, StorageType.HDFS.name());
+ rssConf.set(RssClientConf.RSS_INDEX_READ_LIMIT, 1000);
+ rssConf.set(RssClientConf.RSS_CLIENT_READ_BUFFER_SIZE, "1000");
RssShuffleReader<String, String> rssShuffleReaderSpy = spy(new
RssShuffleReader<>(0, 1, contextMock,
- handleMock, basePath, 1000, conf, StorageType.HDFS.name(),
- 1000, 2, 10, blockIdBitmap, taskIdBitmap, new RssConf()));
+ handleMock, basePath, conf, 2, 10, blockIdBitmap, taskIdBitmap,
rssConf));
validateResult(rssShuffleReaderSpy.read(), expectedData, 10);
}
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 57c8d92f..d5fec8cb 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -151,11 +151,12 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
this.dataCommitPoolSize =
sparkConf.get(RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE);
int unregisterThreadPoolSize =
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE);
int unregisterRequestTimeoutSec =
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_REQUEST_TIMEOUT_SEC);
+ RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
shuffleWriteClient = ShuffleClientFactory
.getInstance()
.createShuffleWriteClient(clientType, retryMax, retryIntervalMax,
heartBeatThreadNum,
dataReplica, dataReplicaWrite, dataReplicaRead,
dataReplicaSkipEnabled, dataTransferPoolSize,
- dataCommitPoolSize, unregisterThreadPoolSize,
unregisterRequestTimeoutSec);
+ dataCommitPoolSize, unregisterThreadPoolSize,
unregisterRequestTimeoutSec, rssConf);
registerCoordinator();
// fetch client conf and apply them if necessary and disable ESS
if (isDriver && dynamicConfEnabled) {
@@ -179,7 +180,6 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
if (isDriver) {
heartBeatScheduledExecutorService =
ThreadUtils.getDaemonSingleThreadScheduledExecutor("rss-heartbeat");
- RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
if (rssConf.getBoolean(RssClientConfig.RSS_RESUBMIT_STAGE, false)
&& RssSparkShuffleUtils.isStageResubmitSupported()) {
LOG.info("stage resubmit is supported and enabled");
@@ -273,7 +273,8 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
dataTransferPoolSize,
dataCommitPoolSize,
unregisterThreadPoolSize,
- unregisterRequestTimeoutSec
+ unregisterRequestTimeoutSec,
+ RssSparkConfig.toRssConf(sparkConf)
);
this.taskToSuccessBlockIds = taskToSuccessBlockIds;
this.taskToFailedBlockIds = taskToFailedBlockIds;
@@ -473,16 +474,8 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
if (!(handle instanceof RssShuffleHandle)) {
throw new RssException("Unexpected ShuffleHandle:" +
handle.getClass().getName());
}
- final String storageType =
sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key());
- final int indexReadLimit =
sparkConf.get(RssSparkConfig.RSS_INDEX_READ_LIMIT);
RssShuffleHandle<K, C, ?> rssShuffleHandle = (RssShuffleHandle<K, C, ?>)
handle;
final int partitionNum =
rssShuffleHandle.getDependency().partitioner().numPartitions();
- long readBufferSize =
sparkConf.getSizeAsBytes(RssSparkConfig.RSS_CLIENT_READ_BUFFER_SIZE.key(),
- RssSparkConfig.RSS_CLIENT_READ_BUFFER_SIZE.defaultValue().get());
- if (readBufferSize > Integer.MAX_VALUE) {
- LOG.warn(RssSparkConfig.RSS_CLIENT_READ_BUFFER_SIZE.key() + " can
support 2g as max");
- readBufferSize = Integer.MAX_VALUE;
- }
int shuffleId = rssShuffleHandle.getShuffleId();
Map<Integer, List<ShuffleServerInfo>> allPartitionToServers =
rssShuffleHandle.getPartitionToServers();
Map<Integer, List<ShuffleServerInfo>> requirePartitionToServers =
allPartitionToServers.entrySet()
@@ -518,10 +511,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
context,
rssShuffleHandle,
shuffleRemoteStoragePath,
- indexReadLimit,
readerHadoopConf,
- storageType,
- (int) readBufferSize,
partitionNum,
RssUtils.generatePartitionToBitmap(blockIdBitmap, startPartition,
endPartition),
taskIdBitmap,
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index 7cba5051..b9c8e36f 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -69,10 +69,7 @@ public class RssShuffleReader<K, C> implements
ShuffleReader<K, C> {
private Serializer serializer;
private String taskId;
private String basePath;
- private int indexReadLimit;
- private int readBufferSize;
private int partitionNum;
- private String storageType;
private Map<Integer, Roaring64NavigableMap> partitionToExpectBlocks;
private Roaring64NavigableMap taskIdBitmap;
private Configuration hadoopConf;
@@ -90,10 +87,7 @@ public class RssShuffleReader<K, C> implements
ShuffleReader<K, C> {
TaskContext context,
RssShuffleHandle<K, C, ?> rssShuffleHandle,
String basePath,
- int indexReadLimit,
Configuration hadoopConf,
- String storageType,
- int readBufferSize,
int partitionNum,
Map<Integer, Roaring64NavigableMap> partitionToExpectBlocks,
Roaring64NavigableMap taskIdBitmap,
@@ -111,9 +105,6 @@ public class RssShuffleReader<K, C> implements
ShuffleReader<K, C> {
this.serializer = rssShuffleHandle.getDependency().serializer();
this.taskId = "" + context.taskAttemptId() + "_" + context.attemptNumber();
this.basePath = basePath;
- this.indexReadLimit = indexReadLimit;
- this.storageType = storageType;
- this.readBufferSize = readBufferSize;
this.partitionNum = partitionNum;
this.partitionToExpectBlocks = partitionToExpectBlocks;
this.taskIdBitmap = taskIdBitmap;
@@ -225,10 +216,9 @@ public class RssShuffleReader<K, C> implements
ShuffleReader<K, C> {
boolean expectedTaskIdsBitmapFilterEnable = !(mapStartIndex == 0 &&
mapEndIndex == Integer.MAX_VALUE)
|| shuffleServerInfoList.size() > 1;
CreateShuffleReadClientRequest request = new
CreateShuffleReadClientRequest(
- appId, shuffleId, partition, storageType, basePath,
indexReadLimit, readBufferSize,
- 1, partitionNum, partitionToExpectBlocks.get(partition),
taskIdBitmap, shuffleServerInfoList,
- hadoopConf, dataDistributionType,
expectedTaskIdsBitmapFilterEnable,
- rssConf.getBoolean(RssClientConf.OFF_HEAP_MEMORY_ENABLE));
+ appId, shuffleId, partition, basePath, 1, partitionNum,
+ partitionToExpectBlocks.get(partition), taskIdBitmap,
shuffleServerInfoList, hadoopConf,
+ dataDistributionType, expectedTaskIdsBitmapFilterEnable, rssConf);
ShuffleReadClient shuffleReadClient =
ShuffleClientFactory.getInstance().createShuffleReadClient(request);
RssShuffleDataIterator<K, C> iterator = new RssShuffleDataIterator<>(
shuffleDependency.serializer(), shuffleReadClient,
diff --git
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
index 62cd6fc2..83593976 100644
---
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
+++
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
@@ -38,6 +38,7 @@ import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.storage.handler.impl.HdfsShuffleWriteHandler;
import org.apache.uniffle.storage.util.StorageType;
@@ -88,6 +89,10 @@ public class RssShuffleReaderTest extends
AbstractRssReaderTest {
Map<Integer, Roaring64NavigableMap> partitionToExpectBlocks =
Maps.newHashMap();
partitionToExpectBlocks.put(0, blockIdBitmap);
+ RssConf rssConf = new RssConf();
+ rssConf.set(RssClientConf.RSS_STORAGE_TYPE, StorageType.HDFS.name());
+ rssConf.set(RssClientConf.RSS_INDEX_READ_LIMIT, 1000);
+ rssConf.set(RssClientConf.RSS_CLIENT_READ_BUFFER_SIZE, "1000");
RssShuffleReader<String, String> rssShuffleReaderSpy = spy(new
RssShuffleReader<>(
0,
1,
@@ -96,15 +101,12 @@ public class RssShuffleReaderTest extends
AbstractRssReaderTest {
contextMock,
handleMock,
basePath,
- 1000,
conf,
- StorageType.HDFS.name(),
- 1000,
1,
partitionToExpectBlocks,
taskIdBitmap,
new ShuffleReadMetrics(),
- new RssConf(),
+ rssConf,
ShuffleDataDistributionType.NORMAL
));
validateResult(rssShuffleReaderSpy.read(), expectedData, 10);
@@ -120,15 +122,12 @@ public class RssShuffleReaderTest extends
AbstractRssReaderTest {
contextMock,
handleMock,
basePath,
- 1000,
conf,
- StorageType.HDFS.name(),
- 1000,
2,
partitionToExpectBlocks,
taskIdBitmap,
new ShuffleReadMetrics(),
- new RssConf(),
+ rssConf,
ShuffleDataDistributionType.NORMAL
));
validateResult(rssShuffleReaderSpy1.read(), expectedData, 18);
@@ -141,15 +140,12 @@ public class RssShuffleReaderTest extends
AbstractRssReaderTest {
contextMock,
handleMock,
basePath,
- 1000,
conf,
- StorageType.HDFS.name(),
- 1000,
2,
partitionToExpectBlocks,
Roaring64NavigableMap.bitmapOf(),
new ShuffleReadMetrics(),
- new RssConf(),
+ rssConf,
ShuffleDataDistributionType.NORMAL
));
validateResult(rssShuffleReaderSpy2.read(), Maps.newHashMap(), 0);
diff --git
a/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
b/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
index 710d9a13..16a9c7ae 100644
---
a/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
+++
b/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
@@ -22,6 +22,7 @@ import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.impl.ShuffleReadClientImpl;
import org.apache.uniffle.client.impl.ShuffleWriteClientImpl;
import org.apache.uniffle.client.request.CreateShuffleReadClientRequest;
+import org.apache.uniffle.common.config.RssConf;
public class ShuffleClientFactory {
@@ -42,13 +43,31 @@ public class ShuffleClientFactory {
int replica, int replicaWrite, int replicaRead, boolean
replicaSkipEnabled, int dataTransferPoolSize,
int dataCommitPoolSize) {
return createShuffleWriteClient(clientType, retryMax, retryIntervalMax,
heartBeatThreadNum, replica,
- replicaWrite, replicaRead, replicaSkipEnabled, dataTransferPoolSize,
dataCommitPoolSize, 10, 10);
+ replicaWrite, replicaRead, replicaSkipEnabled, dataTransferPoolSize,
dataCommitPoolSize,
+ 10, 10, new RssConf());
+ }
+
+ public ShuffleWriteClient createShuffleWriteClient(
+ String clientType, int retryMax, long retryIntervalMax, int
heartBeatThreadNum,
+ int replica, int replicaWrite, int replicaRead, boolean
replicaSkipEnabled, int dataTransferPoolSize,
+ int dataCommitPoolSize, RssConf rssConf) {
+ return createShuffleWriteClient(clientType, retryMax, retryIntervalMax,
heartBeatThreadNum, replica,
+ replicaWrite, replicaRead, replicaSkipEnabled, dataTransferPoolSize,
dataCommitPoolSize, 10, 10, rssConf);
}
public ShuffleWriteClient createShuffleWriteClient(
String clientType, int retryMax, long retryIntervalMax, int
heartBeatThreadNum,
int replica, int replicaWrite, int replicaRead, boolean
replicaSkipEnabled, int dataTransferPoolSize,
int dataCommitPoolSize, int unregisterThreadPoolSize, int
unregisterRequestTimeoutSec) {
+ return createShuffleWriteClient(clientType, retryMax, retryIntervalMax,
heartBeatThreadNum, replica,
+ replicaWrite, replicaRead, replicaSkipEnabled, dataTransferPoolSize,
dataCommitPoolSize,
+ unregisterThreadPoolSize, unregisterRequestTimeoutSec, new RssConf());
+ }
+
+ public ShuffleWriteClient createShuffleWriteClient(
+ String clientType, int retryMax, long retryIntervalMax, int
heartBeatThreadNum,
+ int replica, int replicaWrite, int replicaRead, boolean
replicaSkipEnabled, int dataTransferPoolSize,
+ int dataCommitPoolSize, int unregisterThreadPoolSize, int
unregisterRequestTimeoutSec, RssConf rssConf) {
// If replica > replicaWrite, blocks maybe be sent for 2 rounds.
// We need retry less times in this case for let the first round fail fast.
if (replicaSkipEnabled && replica > replicaWrite) {
@@ -66,20 +85,18 @@ public class ShuffleClientFactory {
dataTransferPoolSize,
dataCommitPoolSize,
unregisterThreadPoolSize,
- unregisterRequestTimeoutSec
+ unregisterRequestTimeoutSec,
+ rssConf
);
}
public ShuffleReadClient
createShuffleReadClient(CreateShuffleReadClientRequest request) {
return new ShuffleReadClientImpl(
- request.getStorageType(),
request.getAppId(),
request.getShuffleId(),
request.getPartitionId(),
- request.getIndexReadLimit(),
request.getPartitionNumPerRange(),
request.getPartitionNum(),
- request.getReadBufferSize(),
request.getBasePath(),
request.getBlockIdBitmap(),
request.getTaskIdBitmap(),
@@ -88,7 +105,7 @@ public class ShuffleClientFactory {
request.getIdHelper(),
request.getShuffleDataDistributionType(),
request.isExpectedTaskIdsBitmapFilterEnable(),
- request.isOffHeapEnabled()
+ request.getRssConf()
);
}
}
diff --git
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
index 4773e30e..958d5f6f 100644
---
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
+++
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
@@ -36,6 +36,8 @@ import org.apache.uniffle.common.BufferSegment;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleDataResult;
import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssClientConf;
+import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssFetchFailedException;
import org.apache.uniffle.common.util.ChecksumUtils;
import org.apache.uniffle.common.util.IdHelper;
@@ -47,7 +49,7 @@ import
org.apache.uniffle.storage.request.CreateShuffleReadHandlerRequest;
public class ShuffleReadClientImpl implements ShuffleReadClient {
private static final Logger LOG =
LoggerFactory.getLogger(ShuffleReadClientImpl.class);
- private final List<ShuffleServerInfo> shuffleServerInfoList;
+ private List<ShuffleServerInfo> shuffleServerInfoList;
private int shuffleId;
private int partitionId;
private ByteBuffer readBuffer;
@@ -60,10 +62,57 @@ public class ShuffleReadClientImpl implements
ShuffleReadClient {
private AtomicLong copyTime = new AtomicLong(0);
private AtomicLong crcCheckTime = new AtomicLong(0);
private ClientReadHandler clientReadHandler;
- private final IdHelper idHelper;
+ private IdHelper idHelper;
public ShuffleReadClientImpl(
- String storageType,
+ String appId,
+ int shuffleId,
+ int partitionId,
+ int partitionNumPerRange,
+ int partitionNum,
+ String storageBasePath,
+ Roaring64NavigableMap blockIdBitmap,
+ Roaring64NavigableMap taskIdBitmap,
+ List<ShuffleServerInfo> shuffleServerInfoList,
+ Configuration hadoopConf,
+ IdHelper idHelper,
+ ShuffleDataDistributionType dataDistributionType,
+ boolean expectedTaskIdsBitmapFilterEnable,
+ RssConf rssConf) {
+ final int indexReadLimit = rssConf.get(RssClientConf.RSS_INDEX_READ_LIMIT);
+ final String storageType = rssConf.get(RssClientConf.RSS_STORAGE_TYPE);
+ long readBufferSize =
rssConf.getSizeAsBytes(RssClientConf.RSS_CLIENT_READ_BUFFER_SIZE.key(),
+ RssClientConf.RSS_CLIENT_READ_BUFFER_SIZE.defaultValue());
+ if (readBufferSize > Integer.MAX_VALUE) {
+ LOG.warn(RssClientConf.RSS_CLIENT_READ_BUFFER_SIZE.key() + " can support
2g as max");
+ readBufferSize = Integer.MAX_VALUE;
+ }
+ boolean offHeapEnabled = rssConf.get(RssClientConf.OFF_HEAP_MEMORY_ENABLE);
+ init(storageType, appId, shuffleId, partitionId, indexReadLimit,
partitionNumPerRange, partitionNum,
+ (int) readBufferSize, storageBasePath, blockIdBitmap, taskIdBitmap,
shuffleServerInfoList, hadoopConf,
+ idHelper, dataDistributionType, expectedTaskIdsBitmapFilterEnable,
offHeapEnabled, rssConf);
+ }
+
+ public ShuffleReadClientImpl(
+ String appId,
+ int shuffleId,
+ int partitionId,
+ int partitionNumPerRange,
+ int partitionNum,
+ String storageBasePath,
+ Roaring64NavigableMap blockIdBitmap,
+ Roaring64NavigableMap taskIdBitmap,
+ List<ShuffleServerInfo> shuffleServerInfoList,
+ Configuration hadoopConf,
+ IdHelper idHelper,
+ RssConf rssConf) {
+ this(appId, shuffleId, partitionId, partitionNumPerRange,
+ partitionNum, storageBasePath, blockIdBitmap, taskIdBitmap,
+ shuffleServerInfoList, hadoopConf, idHelper,
+ ShuffleDataDistributionType.NORMAL, false, rssConf);
+ }
+
+ private void init(String storageType,
String appId,
int shuffleId,
int partitionId,
@@ -79,7 +128,8 @@ public class ShuffleReadClientImpl implements
ShuffleReadClient {
IdHelper idHelper,
ShuffleDataDistributionType dataDistributionType,
boolean expectedTaskIdsBitmapFilterEnable,
- boolean offHeapEnabled) {
+ boolean offHeapEnabled,
+ RssConf rssConf) {
this.shuffleId = shuffleId;
this.partitionId = partitionId;
this.blockIdBitmap = blockIdBitmap;
@@ -95,7 +145,7 @@ public class ShuffleReadClientImpl implements
ShuffleReadClient {
request.setIndexReadLimit(indexReadLimit);
request.setPartitionNumPerRange(partitionNumPerRange);
request.setPartitionNum(partitionNum);
- request.setReadBufferSize(readBufferSize);
+ request.setReadBufferSize((int) readBufferSize);
request.setStorageBasePath(storageBasePath);
request.setShuffleServerInfoList(shuffleServerInfoList);
request.setHadoopConf(hadoopConf);
@@ -104,6 +154,7 @@ public class ShuffleReadClientImpl implements
ShuffleReadClient {
request.setDistributionType(dataDistributionType);
request.setIdHelper(idHelper);
request.setExpectTaskIds(taskIdBitmap);
+ request.setClientConf(rssConf);
if (expectedTaskIdsBitmapFilterEnable) {
request.useExpectedTaskIdsBitmapFilter();
}
@@ -143,10 +194,14 @@ public class ShuffleReadClientImpl implements
ShuffleReadClient {
List<ShuffleServerInfo> shuffleServerInfoList,
Configuration hadoopConf,
IdHelper idHelper) {
- this(storageType, appId, shuffleId, partitionId, indexReadLimit,
+ RssConf rssConf = new RssConf();
+ rssConf.set(RssClientConf.RSS_STORAGE_TYPE, storageType);
+ rssConf.set(RssClientConf.RSS_INDEX_READ_LIMIT, indexReadLimit);
+ rssConf.set(RssClientConf.RSS_CLIENT_READ_BUFFER_SIZE,
String.valueOf(readBufferSize));
+ init(storageType, appId, shuffleId, partitionId, indexReadLimit,
partitionNumPerRange, partitionNum, readBufferSize, storageBasePath,
blockIdBitmap, taskIdBitmap, shuffleServerInfoList, hadoopConf,
- idHelper, ShuffleDataDistributionType.NORMAL, false, false);
+ idHelper, ShuffleDataDistributionType.NORMAL, false, false, rssConf);
}
@Override
diff --git
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
index 1b072dd9..cfc58ad9 100644
---
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
+++
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
@@ -84,6 +84,7 @@ import org.apache.uniffle.common.ShuffleAssignmentsInfo;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.exception.RssFetchFailedException;
import org.apache.uniffle.common.rpc.StatusCode;
@@ -111,6 +112,7 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
private final int unregisterThreadPoolSize;
private final int unregisterRequestTimeSec;
private Set<ShuffleServerInfo> defectiveServers;
+ private RssConf rssConf;
public ShuffleWriteClientImpl(
String clientType,
@@ -125,6 +127,25 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
int dataCommitPoolSize,
int unregisterThreadPoolSize,
int unregisterRequestTimeSec) {
+ this(clientType, retryMax, retryIntervalMax, heartBeatThreadNum, replica,
replicaWrite, replicaRead,
+ replicaSkipEnabled, dataTransferPoolSize, dataCommitPoolSize,
unregisterThreadPoolSize,
+ unregisterRequestTimeSec, new RssConf());
+ }
+
+ public ShuffleWriteClientImpl(
+ String clientType,
+ int retryMax,
+ long retryIntervalMax,
+ int heartBeatThreadNum,
+ int replica,
+ int replicaWrite,
+ int replicaRead,
+ boolean replicaSkipEnabled,
+ int dataTransferPoolSize,
+ int dataCommitPoolSize,
+ int unregisterThreadPoolSize,
+ int unregisterRequestTimeSec,
+ RssConf rssConf) {
this.clientType = clientType;
this.retryMax = retryMax;
this.retryIntervalMax = retryIntervalMax;
@@ -141,6 +162,7 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
if (replica > 1) {
defectiveServers = Sets.newConcurrentHashSet();
}
+ this.rssConf = rssConf;
}
private boolean sendShuffleDataAsync(
@@ -680,7 +702,8 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
callableList.add(() -> {
try {
ShuffleServerClient client =
-
ShuffleServerClientFactory.getInstance().getShuffleServerClient(clientType,
shuffleServerInfo);
+
ShuffleServerClientFactory.getInstance().getShuffleServerClient(
+ clientType, shuffleServerInfo, rssConf);
RssAppHeartBeatResponse response = client.sendHeartBeat(request);
if (response.getStatusCode() != StatusCode.SUCCESS) {
LOG.warn("Failed to send heartbeat to " + shuffleServerInfo);
@@ -745,7 +768,8 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
callableList.add(() -> {
try {
ShuffleServerClient client =
-
ShuffleServerClientFactory.getInstance().getShuffleServerClient(clientType,
shuffleServerInfo);
+
ShuffleServerClientFactory.getInstance().getShuffleServerClient(
+ clientType, shuffleServerInfo, rssConf);
RssUnregisterShuffleResponse response =
client.unregisterShuffle(request);
if (response.getStatusCode() != StatusCode.SUCCESS) {
LOG.warn("Failed to unregister shuffle to " +
shuffleServerInfo);
@@ -798,7 +822,7 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
+  /**
+   * Returns the shuffle-server client for {@code shuffleServerInfo} from the factory's
+   * per-client-type cache, using this writer's {@code clientType} and {@code rssConf}.
+   * NOTE(review): the factory only uses {@code rssConf} when it first creates the client;
+   * an already-cached client is returned unchanged.
+   */
@VisibleForTesting
public ShuffleServerClient getShuffleServerClient(ShuffleServerInfo
shuffleServerInfo) {
- return
ShuffleServerClientFactory.getInstance().getShuffleServerClient(clientType,
shuffleServerInfo);
+ return
ShuffleServerClientFactory.getInstance().getShuffleServerClient(clientType,
shuffleServerInfo, rssConf);
}
@VisibleForTesting
diff --git
a/client/src/main/java/org/apache/uniffle/client/request/CreateShuffleReadClientRequest.java
b/client/src/main/java/org/apache/uniffle/client/request/CreateShuffleReadClientRequest.java
index a6f676fe..a56252bb 100644
---
a/client/src/main/java/org/apache/uniffle/client/request/CreateShuffleReadClientRequest.java
+++
b/client/src/main/java/org/apache/uniffle/client/request/CreateShuffleReadClientRequest.java
@@ -25,6 +25,7 @@ import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.apache.uniffle.client.util.DefaultIdHelper;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.util.IdHelper;
public class CreateShuffleReadClientRequest {
@@ -32,10 +33,7 @@ public class CreateShuffleReadClientRequest {
private String appId;
private int shuffleId;
private int partitionId;
- private String storageType;
private String basePath;
- private int indexReadLimit;
- private int readBufferSize;
private int partitionNumPerRange;
private int partitionNum;
private Roaring64NavigableMap blockIdBitmap;
@@ -45,16 +43,13 @@ public class CreateShuffleReadClientRequest {
private IdHelper idHelper;
private ShuffleDataDistributionType shuffleDataDistributionType =
ShuffleDataDistributionType.NORMAL;
private boolean expectedTaskIdsBitmapFilterEnable = false;
- private boolean offHeapEnabled = false;
+ private RssConf rssConf;
public CreateShuffleReadClientRequest(
String appId,
int shuffleId,
int partitionId,
- String storageType,
String basePath,
- int indexReadLimit,
- int readBufferSize,
int partitionNumPerRange,
int partitionNum,
Roaring64NavigableMap blockIdBitmap,
@@ -63,10 +58,10 @@ public class CreateShuffleReadClientRequest {
Configuration hadoopConf,
ShuffleDataDistributionType dataDistributionType,
boolean expectedTaskIdsBitmapFilterEnable,
- boolean offHeapEnabled) {
- this(appId, shuffleId, partitionId, storageType, basePath, indexReadLimit,
readBufferSize,
+ RssConf rssConf) {
+ this(appId, shuffleId, partitionId, basePath,
partitionNumPerRange, partitionNum, blockIdBitmap, taskIdBitmap,
shuffleServerInfoList,
- hadoopConf, new DefaultIdHelper(), expectedTaskIdsBitmapFilterEnable,
offHeapEnabled);
+ hadoopConf, new DefaultIdHelper(), expectedTaskIdsBitmapFilterEnable,
rssConf);
this.shuffleDataDistributionType = dataDistributionType;
}
@@ -74,10 +69,7 @@ public class CreateShuffleReadClientRequest {
String appId,
int shuffleId,
int partitionId,
- String storageType,
String basePath,
- int indexReadLimit,
- int readBufferSize,
int partitionNumPerRange,
int partitionNum,
Roaring64NavigableMap blockIdBitmap,
@@ -86,14 +78,30 @@ public class CreateShuffleReadClientRequest {
Configuration hadoopConf,
IdHelper idHelper,
boolean expectedTaskIdsBitmapFilterEnable,
- boolean offHeapEnabled) {
+ RssConf rssConf) {
+ this(appId, shuffleId, partitionId, basePath, partitionNumPerRange,
+ partitionNum, blockIdBitmap, taskIdBitmap, shuffleServerInfoList,
hadoopConf, idHelper,
+ expectedTaskIdsBitmapFilterEnable);
+ this.rssConf = rssConf;
+ }
+
+ public CreateShuffleReadClientRequest(
+ String appId,
+ int shuffleId,
+ int partitionId,
+ String basePath,
+ int partitionNumPerRange,
+ int partitionNum,
+ Roaring64NavigableMap blockIdBitmap,
+ Roaring64NavigableMap taskIdBitmap,
+ List<ShuffleServerInfo> shuffleServerInfoList,
+ Configuration hadoopConf,
+ IdHelper idHelper,
+ boolean expectedTaskIdsBitmapFilterEnable) {
this.appId = appId;
this.shuffleId = shuffleId;
this.partitionId = partitionId;
- this.storageType = storageType;
this.basePath = basePath;
- this.indexReadLimit = indexReadLimit;
- this.readBufferSize = readBufferSize;
this.partitionNumPerRange = partitionNumPerRange;
this.partitionNum = partitionNum;
this.blockIdBitmap = blockIdBitmap;
@@ -102,17 +110,13 @@ public class CreateShuffleReadClientRequest {
this.hadoopConf = hadoopConf;
this.idHelper = idHelper;
this.expectedTaskIdsBitmapFilterEnable = expectedTaskIdsBitmapFilterEnable;
- this.offHeapEnabled = offHeapEnabled;
}
public CreateShuffleReadClientRequest(
String appId,
int shuffleId,
int partitionId,
- String storageType,
String basePath,
- int indexReadLimit,
- int readBufferSize,
int partitionNumPerRange,
int partitionNum,
Roaring64NavigableMap blockIdBitmap,
@@ -120,10 +124,11 @@ public class CreateShuffleReadClientRequest {
List<ShuffleServerInfo> shuffleServerInfoList,
Configuration hadoopConf,
boolean expectedTaskIdsBitmapFilterEnable,
- boolean offHeapEnabled) {
- this(appId, shuffleId, partitionId, storageType, basePath, indexReadLimit,
readBufferSize,
+ RssConf rssConf) {
+ this(appId, shuffleId, partitionId, basePath,
partitionNumPerRange, partitionNum, blockIdBitmap, taskIdBitmap,
shuffleServerInfoList,
- hadoopConf, new DefaultIdHelper(), expectedTaskIdsBitmapFilterEnable,
offHeapEnabled);
+ hadoopConf, new DefaultIdHelper(), expectedTaskIdsBitmapFilterEnable);
+ this.rssConf = rssConf;
}
public String getAppId() {
@@ -146,22 +151,10 @@ public class CreateShuffleReadClientRequest {
return partitionNum;
}
- public String getStorageType() {
- return storageType;
- }
-
public String getBasePath() {
return basePath;
}
- public int getIndexReadLimit() {
- return indexReadLimit;
- }
-
- public int getReadBufferSize() {
- return readBufferSize;
- }
-
public Roaring64NavigableMap getBlockIdBitmap() {
return blockIdBitmap;
}
@@ -190,7 +183,7 @@ public class CreateShuffleReadClientRequest {
return expectedTaskIdsBitmapFilterEnable;
}
- public boolean isOffHeapEnabled() {
- return offHeapEnabled;
+  /**
+   * Returns the client configuration given to the constructor. This is {@code null} when the
+   * request was built with the constructor overload that takes no {@code RssConf}.
+   */
+ public RssConf getRssConf() {
+ return rssConf;
}
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/ShuffleIndexResult.java
b/common/src/main/java/org/apache/uniffle/common/ShuffleIndexResult.java
index 7e5d5b58..adf28e2a 100644
--- a/common/src/main/java/org/apache/uniffle/common/ShuffleIndexResult.java
+++ b/common/src/main/java/org/apache/uniffle/common/ShuffleIndexResult.java
@@ -17,20 +17,22 @@
package org.apache.uniffle.common;
+import java.nio.ByteBuffer;
+
public class ShuffleIndexResult {
- private final byte[] indexData;
+ private final ByteBuffer indexData;
private long dataFileLen;
+  /** Creates an empty result: a zero-length buffer and a data file length of -1. */
public ShuffleIndexResult() {
- this(new byte[0], -1);
+ this(ByteBuffer.wrap(new byte[0]), -1);
}
- public ShuffleIndexResult(byte[] bytes, long dataFileLen) {
+  /**
+   * Stores the given index buffer as-is (no copy); its position and limit are not reset here.
+   *
+   * @param bytes index data; a {@code null} buffer is tolerated by {@code isEmpty()}
+   * @param dataFileLen presumably the length of the corresponding data file (the no-arg
+   *     constructor uses -1)
+   */
+ public ShuffleIndexResult(ByteBuffer bytes, long dataFileLen) {
this.indexData = bytes;
this.dataFileLen = dataFileLen;
}
- public byte[] getIndexData() {
- return indexData;
+  /**
+   * Returns the index data as an independent view of the stored buffer.
+   *
+   * <p>The view shares content with the stored buffer but has its own position, limit and
+   * mark. Consumers that read through it with relative gets (e.g. the segment splitters' calls
+   * to getLong/getInt) therefore do not drain this result. isEmpty() keeps its answer and the
+   * data can be split again, as it could with the previous byte[] representation. The original
+   * byte order is re-applied because ByteBuffer#duplicate() always yields BIG_ENDIAN.
+   *
+   * @return a duplicate of the index buffer, or {@code null} if this result holds no buffer
+   */
+  public ByteBuffer getIndexData() {
+    return indexData == null ? null : indexData.duplicate().order(indexData.order());
}
@@ -39,6 +41,6 @@ public class ShuffleIndexResult {
}
+  /**
+   * Returns true when there is no index buffer, or when it has no bytes between its position
+   * and its limit.
+   * NOTE(review): remaining() is position-dependent, so a stored buffer that has already been
+   * read through its own position will also report empty.
+   */
public boolean isEmpty() {
- return indexData == null || indexData.length == 0;
+ return indexData == null || indexData.remaining() == 0;
}
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
index 1823c1c0..f7d49c1f 100644
--- a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
+++ b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
@@ -115,4 +115,21 @@ public class RssClientConf {
.booleanType()
.defaultValue(false)
.withDescription("Client can use off heap memory");
+
+  /**
+   * Index read limit handed to the shuffle read client (default 500).
+   * NOTE(review): unlike the neighbouring options this one has no withDescription().
+   */
+ public static final ConfigOption<Integer> RSS_INDEX_READ_LIMIT =
ConfigOptions
+ .key("rss.index.read.limit")
+ .intType()
+ .defaultValue(500);
+
+  /** Storage type used by the read client; defaults to the empty string. */
+ public static final ConfigOption<String> RSS_STORAGE_TYPE = ConfigOptions
+ .key("rss.storage.type")
+ .stringType()
+ .defaultValue("")
+ .withDescription("Supports MEMORY_LOCALFILE, MEMORY_HDFS,
MEMORY_LOCALFILE_HDFS");
+
+  /**
+   * Max data size read from storage, given as a size string (default "14m").
+   * NOTE(review): the read client appears to clamp values above Integer.MAX_VALUE (about 2g)
+   * and log a warning.
+   */
+ public static final ConfigOption<String> RSS_CLIENT_READ_BUFFER_SIZE =
ConfigOptions
+ .key("rss.client.read.buffer.size")
+ .stringType()
+ .defaultValue("14m")
+ .withDescription("The max data size read from storage");
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClient.java
b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClient.java
index 34ebb20a..c82408af 100644
---
a/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClient.java
+++
b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClient.java
@@ -24,6 +24,7 @@ import java.util.Objects;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
+import com.google.common.util.concurrent.SettableFuture;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.util.concurrent.Future;
@@ -31,8 +32,10 @@ import io.netty.util.concurrent.GenericFutureListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.netty.handle.TransportResponseHandler;
import org.apache.uniffle.common.netty.protocol.Message;
+import org.apache.uniffle.common.netty.protocol.RpcResponse;
import org.apache.uniffle.common.util.NettyUtils;
@@ -63,7 +66,7 @@ public class TransportClient implements Closeable {
return channel.remoteAddress();
}
- public ChannelFuture sendShuffleData(Message message, RpcResponseCallback
callback) {
+ public ChannelFuture sendRpc(Message message, RpcResponseCallback callback) {
if (logger.isTraceEnabled()) {
logger.trace("Pushing data to {}", NettyUtils.getRemoteAddress(channel));
}
@@ -73,6 +76,27 @@ public class TransportClient implements Closeable {
return channel.writeAndFlush(message).addListener(listener);
}
+  /**
+   * Sends {@code message} and blocks until its response arrives or {@code timeoutMs} elapses.
+   *
+   * @param message the request to send through {@link #sendRpc}
+   * @param timeoutMs maximum time to wait for the response, in milliseconds
+   * @return the response delivered to the RPC callback
+   * @throws RssException if the wait is interrupted (the interrupt flag is restored), if the
+   *     callback reports a failure (that failure becomes the cause), or if the timeout expires
+   */
+  public RpcResponse sendRpcSync(Message message, long timeoutMs) {
+    SettableFuture<RpcResponse> result = SettableFuture.create();
+    RpcResponseCallback callback = new RpcResponseCallback() {
+      @Override
+      public void onSuccess(RpcResponse response) {
+        result.set(response);
+      }
+
+      @Override
+      public void onFailure(Throwable e) {
+        result.setException(e);
+      }
+    };
+    sendRpc(message, callback);
+    try {
+      return result.get(timeoutMs, TimeUnit.MILLISECONDS);
+    } catch (InterruptedException e) {
+      // Restore the interrupt flag so callers further up the stack can still observe it.
+      Thread.currentThread().interrupt();
+      throw new RssException("Interrupted while waiting for RPC response from "
+          + NettyUtils.getRemoteAddress(channel), e);
+    } catch (java.util.concurrent.ExecutionException e) {
+      // Surface the failure reported by the callback rather than the ExecutionException wrapper.
+      Throwable cause = e.getCause() != null ? e.getCause() : e;
+      throw new RssException("RPC to " + NettyUtils.getRemoteAddress(channel) + " failed", cause);
+    } catch (java.util.concurrent.TimeoutException e) {
+      throw new RssException("RPC to " + NettyUtils.getRemoteAddress(channel)
+          + " timed out after " + timeoutMs + " ms", e);
+    }
+  }
+
public static long requestId() {
return counter.getAndIncrement();
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/segment/FixedSizeSegmentSplitter.java
b/common/src/main/java/org/apache/uniffle/common/segment/FixedSizeSegmentSplitter.java
index 9c1f0d23..17998b47 100644
---
a/common/src/main/java/org/apache/uniffle/common/segment/FixedSizeSegmentSplitter.java
+++
b/common/src/main/java/org/apache/uniffle/common/segment/FixedSizeSegmentSplitter.java
@@ -46,28 +46,27 @@ public class FixedSizeSegmentSplitter implements
SegmentSplitter {
return Lists.newArrayList();
}
- byte[] indexData = shuffleIndexResult.getIndexData();
+ ByteBuffer indexData = shuffleIndexResult.getIndexData();
long dataFileLen = shuffleIndexResult.getDataFileLen();
return transIndexDataToSegments(indexData, readBufferSize, dataFileLen);
}
- private static List<ShuffleDataSegment> transIndexDataToSegments(byte[]
indexData,
+ private static List<ShuffleDataSegment> transIndexDataToSegments(ByteBuffer
indexData,
int readBufferSize, long dataFileLen) {
- ByteBuffer byteBuffer = ByteBuffer.wrap(indexData);
List<BufferSegment> bufferSegments = Lists.newArrayList();
List<ShuffleDataSegment> dataFileSegments = Lists.newArrayList();
int bufferOffset = 0;
long fileOffset = -1;
long totalLength = 0;
- while (byteBuffer.hasRemaining()) {
+ while (indexData.hasRemaining()) {
try {
- final long offset = byteBuffer.getLong();
- final int length = byteBuffer.getInt();
- final int uncompressLength = byteBuffer.getInt();
- final long crc = byteBuffer.getLong();
- final long blockId = byteBuffer.getLong();
- final long taskAttemptId = byteBuffer.getLong();
+ final long offset = indexData.getLong();
+ final int length = indexData.getInt();
+ final int uncompressLength = indexData.getInt();
+ final long crc = indexData.getLong();
+ final long blockId = indexData.getLong();
+ final long taskAttemptId = indexData.getLong();
// The index file is written, read and parsed sequentially, so these
parsed index segments
// index a continuous shuffle data in the corresponding data file and
the first segment's
diff --git
a/common/src/main/java/org/apache/uniffle/common/segment/LocalOrderSegmentSplitter.java
b/common/src/main/java/org/apache/uniffle/common/segment/LocalOrderSegmentSplitter.java
index a8bcd0ce..11b16dbf 100644
---
a/common/src/main/java/org/apache/uniffle/common/segment/LocalOrderSegmentSplitter.java
+++
b/common/src/main/java/org/apache/uniffle/common/segment/LocalOrderSegmentSplitter.java
@@ -63,10 +63,9 @@ public class LocalOrderSegmentSplitter implements
SegmentSplitter {
return Lists.newArrayList();
}
- byte[] indexData = shuffleIndexResult.getIndexData();
+ ByteBuffer indexData = shuffleIndexResult.getIndexData();
long dataFileLen = shuffleIndexResult.getDataFileLen();
- ByteBuffer byteBuffer = ByteBuffer.wrap(indexData);
List<BufferSegment> bufferSegments = Lists.newArrayList();
List<ShuffleDataSegment> dataFileSegments = Lists.newArrayList();
@@ -88,14 +87,14 @@ public class LocalOrderSegmentSplitter implements
SegmentSplitter {
* 3. Single shuffleDataSegment's blocks should be continuous
*/
int index = 0;
- while (byteBuffer.hasRemaining()) {
+ while (indexData.hasRemaining()) {
try {
- long offset = byteBuffer.getLong();
- int length = byteBuffer.getInt();
- int uncompressLength = byteBuffer.getInt();
- long crc = byteBuffer.getLong();
- long blockId = byteBuffer.getLong();
- long taskAttemptId = byteBuffer.getLong();
+ long offset = indexData.getLong();
+ int length = indexData.getInt();
+ int uncompressLength = indexData.getInt();
+ long crc = indexData.getLong();
+ long blockId = indexData.getLong();
+ long taskAttemptId = indexData.getLong();
totalLen += length;
indexTaskIds.add(taskAttemptId);
diff --git
a/common/src/test/java/org/apache/uniffle/common/segment/FixedSizeSegmentSplitterTest.java
b/common/src/test/java/org/apache/uniffle/common/segment/FixedSizeSegmentSplitterTest.java
index 5655288c..53fbec43 100644
---
a/common/src/test/java/org/apache/uniffle/common/segment/FixedSizeSegmentSplitterTest.java
+++
b/common/src/test/java/org/apache/uniffle/common/segment/FixedSizeSegmentSplitterTest.java
@@ -45,7 +45,8 @@ public class FixedSizeSegmentSplitterTest {
Pair.of(10, 0)
);
- List<ShuffleDataSegment> shuffleDataSegments = splitter.split(new
ShuffleIndexResult(data, dataLength));
+ List<ShuffleDataSegment> shuffleDataSegments = splitter.split(new
ShuffleIndexResult(
+ ByteBuffer.wrap(data), dataLength));
assertEquals(1, shuffleDataSegments.size());
assertEquals(0, shuffleDataSegments.get(0).getOffset());
assertEquals(48, shuffleDataSegments.get(0).getLength());
@@ -70,7 +71,7 @@ public class FixedSizeSegmentSplitterTest {
Pair.of(32, 6),
Pair.of(6, 0)
);
- shuffleDataSegments = splitter.split(new ShuffleIndexResult(data, -1));
+ shuffleDataSegments = splitter.split(new
ShuffleIndexResult(ByteBuffer.wrap(data), -1));
assertEquals(3, shuffleDataSegments.size());
assertEquals(0, shuffleDataSegments.get(0).getOffset());
@@ -91,7 +92,7 @@ public class FixedSizeSegmentSplitterTest {
data = incompleteByteBuffer.array();
// It should throw exception
try {
- splitter.split(new ShuffleIndexResult(data, -1));
+ splitter.split(new ShuffleIndexResult(ByteBuffer.wrap(data), -1));
fail();
} catch (Exception e) {
// ignore
diff --git
a/common/src/test/java/org/apache/uniffle/common/segment/LocalOrderSegmentSplitterTest.java
b/common/src/test/java/org/apache/uniffle/common/segment/LocalOrderSegmentSplitterTest.java
index 418cb7f5..3747196b 100644
---
a/common/src/test/java/org/apache/uniffle/common/segment/LocalOrderSegmentSplitterTest.java
+++
b/common/src/test/java/org/apache/uniffle/common/segment/LocalOrderSegmentSplitterTest.java
@@ -50,7 +50,7 @@ public class LocalOrderSegmentSplitterTest {
Pair.of(8, 9)
);
List<ShuffleDataSegment> dataSegments1 = new
LocalOrderSegmentSplitter(taskIds, 1000)
- .split(new ShuffleIndexResult(data, -1));
+ .split(new ShuffleIndexResult(ByteBuffer.wrap(data), -1));
assertEquals(2, dataSegments1.size());
assertEquals(16, dataSegments1.get(0).getOffset());
@@ -87,7 +87,7 @@ public class LocalOrderSegmentSplitterTest {
Pair.of(1, 4)
);
List<ShuffleDataSegment> dataSegments2 =
- new LocalOrderSegmentSplitter(taskIds, 32).split(new
ShuffleIndexResult(data, -1));
+ new LocalOrderSegmentSplitter(taskIds, 32).split(new
ShuffleIndexResult(ByteBuffer.wrap(data), -1));
assertEquals(2, dataSegments2.size());
assertEquals(0, dataSegments2.get(0).getOffset());
assertEquals(2, dataSegments2.get(0).getLength());
@@ -109,7 +109,7 @@ public class LocalOrderSegmentSplitterTest {
Pair.of(1, 1)
);
List<ShuffleDataSegment> dataSegments3 =
- new LocalOrderSegmentSplitter(taskIds, 3).split(new
ShuffleIndexResult(data, -1));
+ new LocalOrderSegmentSplitter(taskIds, 3).split(new
ShuffleIndexResult(ByteBuffer.wrap(data), -1));
assertEquals(3, dataSegments3.size());
assertEquals(0, dataSegments3.get(0).getOffset());
assertEquals(3, dataSegments3.get(0).getLength());
@@ -133,7 +133,7 @@ public class LocalOrderSegmentSplitterTest {
Pair.of(1, 4)
);
List<ShuffleDataSegment> dataSegments4 =
- new LocalOrderSegmentSplitter(taskIds, 3).split(new
ShuffleIndexResult(data, -1));
+ new LocalOrderSegmentSplitter(taskIds, 3).split(new
ShuffleIndexResult(ByteBuffer.wrap(data), -1));
assertEquals(2, dataSegments4.size());
}
@@ -161,10 +161,10 @@ public class LocalOrderSegmentSplitterTest {
Pair.of(8, 5)
);
List<ShuffleDataSegment> dataSegments1 = new
LocalOrderSegmentSplitter(taskIds, 32)
- .split(new ShuffleIndexResult(data, realDataLength));
+ .split(new ShuffleIndexResult(ByteBuffer.wrap(data), realDataLength));
List<ShuffleDataSegment> dataSegments2 = new FixedSizeSegmentSplitter(32)
- .split(new ShuffleIndexResult(data, realDataLength));
+ .split(new ShuffleIndexResult(ByteBuffer.wrap(data), realDataLength));
checkConsistency(dataSegments1, dataSegments2);
}
@@ -217,7 +217,7 @@ public class LocalOrderSegmentSplitterTest {
Pair.of(10, 3),
Pair.of(9, 1)
);
- List<ShuffleDataSegment> dataSegments = splitter.split(new
ShuffleIndexResult(data, -1));
+ List<ShuffleDataSegment> dataSegments = splitter.split(new
ShuffleIndexResult(ByteBuffer.wrap(data), -1));
assertEquals(2, dataSegments.size());
assertEquals(32, dataSegments.get(0).getOffset());
assertEquals(56, dataSegments.get(0).getLength());
@@ -257,7 +257,7 @@ public class LocalOrderSegmentSplitterTest {
Pair.of(1, 1),
Pair.of(6, 1)
);
- dataSegments = new LocalOrderSegmentSplitter(taskIds, 32).split(new
ShuffleIndexResult(data, -1));
+ dataSegments = new LocalOrderSegmentSplitter(taskIds, 32).split(new
ShuffleIndexResult(ByteBuffer.wrap(data), -1));
assertEquals(2, dataSegments.size());
assertEquals(0, dataSegments.get(0).getOffset());
assertEquals(32, dataSegments.get(0).getLength());
@@ -286,7 +286,7 @@ public class LocalOrderSegmentSplitterTest {
Pair.of(16, 1),
Pair.of(6, 1)
);
- List<ShuffleDataSegment> dataSegments = splitter.split(new
ShuffleIndexResult(data, -1));
+ List<ShuffleDataSegment> dataSegments = splitter.split(new
ShuffleIndexResult(ByteBuffer.wrap(data), -1));
assertEquals(3, dataSegments.size());
assertEquals(0, dataSegments.get(0).getOffset());
@@ -310,7 +310,7 @@ public class LocalOrderSegmentSplitterTest {
Pair.of(16, 2),
Pair.of(6, 1)
);
- dataSegments = splitter.split(new ShuffleIndexResult(data, -1));
+ dataSegments = splitter.split(new
ShuffleIndexResult(ByteBuffer.wrap(data), -1));
assertEquals(2, dataSegments.size());
assertEquals(32, dataSegments.get(0).getOffset());
@@ -333,7 +333,7 @@ public class LocalOrderSegmentSplitterTest {
Pair.of(16, 4),
Pair.of(6, 1)
);
- dataSegments = splitter.split(new ShuffleIndexResult(data, -1));
+ dataSegments = splitter.split(new
ShuffleIndexResult(ByteBuffer.wrap(data), -1));
assertEquals(2, dataSegments.size());
assertEquals(32, dataSegments.get(0).getOffset());
@@ -353,7 +353,8 @@ public class LocalOrderSegmentSplitterTest {
Pair.of(16, 230)
);
taskIds = Roaring64NavigableMap.bitmapOf(230);
- dataSegments = new LocalOrderSegmentSplitter(taskIds, 10000).split(new
ShuffleIndexResult(data, -1));
+ dataSegments = new LocalOrderSegmentSplitter(taskIds, 10000).split(
+ new ShuffleIndexResult(ByteBuffer.wrap(data), -1));
assertEquals(2, dataSegments.size());
assertEquals(16, dataSegments.get(0).getOffset());
assertEquals(16, dataSegments.get(0).getLength());
@@ -374,7 +375,8 @@ public class LocalOrderSegmentSplitterTest {
Pair.of(1, 6)
);
taskIds = Roaring64NavigableMap.bitmapOf(2, 3, 4);
- dataSegments = new LocalOrderSegmentSplitter(taskIds, 10000).split(new
ShuffleIndexResult(data, -1));
+ dataSegments = new LocalOrderSegmentSplitter(taskIds, 10000).split(
+ new ShuffleIndexResult(ByteBuffer.wrap(data), -1));
assertEquals(2, dataSegments.size());
assertEquals(0, dataSegments.get(0).getOffset());
assertEquals(3, dataSegments.get(0).getLength());
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/factory/ShuffleServerClientFactory.java
b/internal-client/src/main/java/org/apache/uniffle/client/factory/ShuffleServerClientFactory.java
index cc0896ac..af81be65 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/factory/ShuffleServerClientFactory.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/factory/ShuffleServerClientFactory.java
@@ -21,8 +21,10 @@ import java.util.Map;
import org.apache.uniffle.client.api.ShuffleServerClient;
import org.apache.uniffle.client.impl.grpc.ShuffleServerGrpcClient;
+import org.apache.uniffle.client.impl.grpc.ShuffleServerGrpcNettyClient;
import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.util.JavaUtils;
public class ShuffleServerClientFactory {
@@ -41,9 +43,15 @@ public class ShuffleServerClientFactory {
return LazyHolder.INSTANCE;
}
- private ShuffleServerClient createShuffleServerClient(String clientType,
ShuffleServerInfo shuffleServerInfo) {
+ private ShuffleServerClient createShuffleServerClient(String clientType,
+ ShuffleServerInfo shuffleServerInfo, RssConf rssConf) {
if (clientType.equalsIgnoreCase(ClientType.GRPC.name())) {
return new ShuffleServerGrpcClient(shuffleServerInfo.getHost(),
shuffleServerInfo.getGrpcPort());
+ } else if (clientType.equalsIgnoreCase(ClientType.GRPC_NETTY.name())) {
+ return new ShuffleServerGrpcNettyClient(rssConf,
+ shuffleServerInfo.getHost(),
+ shuffleServerInfo.getGrpcPort(),
+ shuffleServerInfo.getNettyPort());
} else {
throw new UnsupportedOperationException("Unsupported client type " +
clientType);
}
@@ -51,10 +59,15 @@ public class ShuffleServerClientFactory {
+  /**
+   * Returns the client cached for {@code shuffleServerInfo} under {@code clientType}; if none is
+   * cached yet, it is created with a freshly constructed {@code new RssConf()}.
+   */
public synchronized ShuffleServerClient getShuffleServerClient(
String clientType, ShuffleServerInfo shuffleServerInfo) {
+ return getShuffleServerClient(clientType, shuffleServerInfo, new
RssConf());
+ }
+
+  /**
+   * Returns the client cached for (clientType, shuffleServerInfo), creating it on first use.
+   * {@code rssConf} is consulted only when the client is created, and only the GRPC_NETTY client
+   * type uses it. A cached client is returned unchanged even if a different {@code rssConf} is
+   * passed. Access is serialized by {@code synchronized} on the factory instance.
+   */
+ public synchronized ShuffleServerClient getShuffleServerClient(
+ String clientType, ShuffleServerInfo shuffleServerInfo, RssConf rssConf)
{
clients.putIfAbsent(clientType, JavaUtils.newConcurrentMap());
Map<ShuffleServerInfo, ShuffleServerClient> serverToClients =
clients.get(clientType);
if (serverToClients.get(shuffleServerInfo) == null) {
- serverToClients.put(shuffleServerInfo,
createShuffleServerClient(clientType, shuffleServerInfo));
+ serverToClients.put(shuffleServerInfo,
createShuffleServerClient(clientType, shuffleServerInfo, rssConf));
}
return serverToClients.get(shuffleServerInfo);
}
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
index 7e4e0faf..591a1a47 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
@@ -17,6 +17,7 @@
package org.apache.uniffle.client.impl.grpc;
+import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
@@ -104,8 +105,8 @@ import
org.apache.uniffle.proto.ShuffleServerGrpc.ShuffleServerBlockingStub;
public class ShuffleServerGrpcClient extends GrpcClient implements
ShuffleServerClient {
private static final Logger LOG =
LoggerFactory.getLogger(ShuffleServerGrpcClient.class);
- private static final long FAILED_REQUIRE_ID = -1;
- private static final long RPC_TIMEOUT_DEFAULT_MS = 60000;
+ protected static final long FAILED_REQUIRE_ID = -1;
+ protected static final long RPC_TIMEOUT_DEFAULT_MS = 60000;
private long rpcTimeout = RPC_TIMEOUT_DEFAULT_MS;
private ShuffleServerBlockingStub blockingStub;
@@ -590,7 +591,7 @@ public class ShuffleServerGrpcClient extends GrpcClient
implements ShuffleServer
switch (statusCode) {
case SUCCESS:
response = new RssGetShuffleDataResponse(
- StatusCode.SUCCESS, rpcResponse.getData().toByteArray());
+ StatusCode.SUCCESS,
ByteBuffer.wrap(rpcResponse.getData().toByteArray()));
break;
default:
@@ -625,7 +626,9 @@ public class ShuffleServerGrpcClient extends GrpcClient
implements ShuffleServer
switch (statusCode) {
case SUCCESS:
response = new RssGetShuffleIndexResponse(
- StatusCode.SUCCESS, rpcResponse.getIndexData().toByteArray(),
rpcResponse.getDataFileLen());
+ StatusCode.SUCCESS,
+ ByteBuffer.wrap(rpcResponse.getIndexData().toByteArray()),
+ rpcResponse.getDataFileLen());
break;
default:
@@ -674,7 +677,7 @@ public class ShuffleServerGrpcClient extends GrpcClient
implements ShuffleServer
switch (statusCode) {
case SUCCESS:
response = new RssGetInMemoryShuffleDataResponse(
- StatusCode.SUCCESS, rpcResponse.getData().toByteArray(),
+ StatusCode.SUCCESS,
ByteBuffer.wrap(rpcResponse.getData().toByteArray()),
toBufferSegments(rpcResponse.getShuffleDataBlockSegmentsList()));
break;
default:
@@ -702,7 +705,7 @@ public class ShuffleServerGrpcClient extends GrpcClient
implements ShuffleServer
return ret;
}
- private List<BufferSegment> toBufferSegments(List<ShuffleDataBlockSegment>
blockSegments) {
+ protected List<BufferSegment> toBufferSegments(List<ShuffleDataBlockSegment>
blockSegments) {
List<BufferSegment> ret = Lists.newArrayList();
for (ShuffleDataBlockSegment sdbs : blockSegments) {
ret.add(new BufferSegment(sdbs.getBlockId(), sdbs.getOffset(),
sdbs.getLength(),
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
new file mode 100644
index 00000000..eb50665c
--- /dev/null
+++
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
@@ -0,0 +1,325 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.client.impl.grpc;
+
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicLong;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.client.request.RssGetInMemoryShuffleDataRequest;
+import org.apache.uniffle.client.request.RssGetShuffleDataRequest;
+import org.apache.uniffle.client.request.RssGetShuffleIndexRequest;
+import org.apache.uniffle.client.request.RssSendShuffleDataRequest;
+import org.apache.uniffle.client.response.RssGetInMemoryShuffleDataResponse;
+import org.apache.uniffle.client.response.RssGetShuffleDataResponse;
+import org.apache.uniffle.client.response.RssGetShuffleIndexResponse;
+import org.apache.uniffle.client.response.RssSendShuffleDataResponse;
+import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.exception.NotRetryException;
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.exception.RssFetchFailedException;
+import org.apache.uniffle.common.netty.client.TransportClient;
+import org.apache.uniffle.common.netty.client.TransportClientFactory;
+import org.apache.uniffle.common.netty.client.TransportConf;
+import org.apache.uniffle.common.netty.client.TransportContext;
+import org.apache.uniffle.common.netty.protocol.GetLocalShuffleDataRequest;
+import org.apache.uniffle.common.netty.protocol.GetLocalShuffleDataResponse;
+import org.apache.uniffle.common.netty.protocol.GetLocalShuffleIndexRequest;
+import org.apache.uniffle.common.netty.protocol.GetLocalShuffleIndexResponse;
+import org.apache.uniffle.common.netty.protocol.GetMemoryShuffleDataRequest;
+import org.apache.uniffle.common.netty.protocol.GetMemoryShuffleDataResponse;
+import org.apache.uniffle.common.netty.protocol.RpcResponse;
+import org.apache.uniffle.common.netty.protocol.SendShuffleDataRequest;
+import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.common.util.RetryUtils;
+
+/**
+ * A {@link ShuffleServerGrpcClient} that sends shuffle data and reads shuffle data and index
+ * over Netty, to {@code host:nettyPort}, instead of gRPC. Methods not overridden here,
+ * including the {@code requirePreAllocation} call made before each send, are inherited from
+ * the gRPC client.
+ *
+ * <p>NOTE(review): the read methods wrap the reply's Netty buffer with {@code nioBuffer()},
+ * which may share memory with it; confirm that buffer is not released before callers consume
+ * the returned data.
+ */
+public class ShuffleServerGrpcNettyClient extends ShuffleServerGrpcClient {
+  private static final Logger LOG = LoggerFactory.getLogger(ShuffleServerGrpcNettyClient.class);
+
+  /** Class-wide counter behind {@link #requestId()}, shared by all instances. */
+  private static final AtomicLong REQUEST_ID_COUNTER = new AtomicLong();
+
+  private final int nettyPort;
+  private final TransportClientFactory clientFactory;
+
+  /**
+   * Creates a client that reaches one shuffle server over both gRPC and Netty.
+   *
+   * @param rssConf configuration used to build the Netty {@link TransportConf}
+   * @param host shuffle server host, used by both transports
+   * @param grpcPort port for the inherited gRPC calls
+   * @param nettyPort port for the Netty calls overridden in this class
+   */
+  public ShuffleServerGrpcNettyClient(RssConf rssConf, String host, int grpcPort, int nettyPort) {
+    super(host, grpcPort);
+    this.nettyPort = nettyPort;
+    TransportContext transportContext = new TransportContext(new TransportConf(rssConf));
+    this.clientFactory = new TransportClientFactory(transportContext);
+  }
+
+  /**
+   * Returns the id to attach to the next Netty request.
+   *
+   * @return the next value of a class-wide counter shared by all instances
+   */
+  public static long requestId() {
+    return REQUEST_ID_COUNTER.getAndIncrement();
+  }
+
+  /**
+   * Sends shuffle blocks over Netty, issuing one request per shuffle id.
+   *
+   * <p>For each shuffle id the total block size is first pre-allocated with
+   * {@code requirePreAllocation}, then the blocks are sent; both steps are retried together
+   * through {@link RetryUtils#retry}. A {@link StatusCode#NO_REGISTER} reply is raised as a
+   * {@link NotRetryException}. Sending stops at the first shuffle id whose retried send fails.
+   *
+   * @param request the blocks to send, keyed by shuffle id
+   * @return a response with {@link StatusCode#SUCCESS} if every shuffle id was sent, otherwise
+   *     {@link StatusCode#INTERNAL_ERROR}
+   */
+  @Override
+  public RssSendShuffleDataResponse sendShuffleData(RssSendShuffleDataRequest request) {
+    TransportClient transportClient = getTransportClient();
+    Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> shuffleIdToBlocks =
+        request.getShuffleIdToBlocks();
+    boolean isSuccessful = true;
+
+    for (Map.Entry<Integer, Map<Integer, List<ShuffleBlockInfo>>> stb :
+        shuffleIdToBlocks.entrySet()) {
+      int shuffleId = stb.getKey();
+      int size = 0;
+      int blockNum = 0;
+      for (List<ShuffleBlockInfo> blocks : stb.getValue().values()) {
+        for (ShuffleBlockInfo sbi : blocks) {
+          size += sbi.getSize();
+          blockNum++;
+        }
+      }
+
+      // Effectively-final copies so the retry lambda can capture them.
+      int allocateSize = size;
+      int finalBlockNum = blockNum;
+      try {
+        RetryUtils.retry(() -> {
+          long requireId = requirePreAllocation(
+              allocateSize, request.getRetryMax(), request.getRetryIntervalMax());
+          if (requireId == FAILED_REQUIRE_ID) {
+            throw new RssException(String.format(
+                "requirePreAllocation failed! size[%s], host[%s], port[%s]",
+                allocateSize, host, port));
+          }
+
+          SendShuffleDataRequest sendShuffleDataRequest = new SendShuffleDataRequest(
+              requestId(),
+              request.getAppId(),
+              shuffleId,
+              requireId,
+              stb.getValue(),
+              System.currentTimeMillis());
+          long start = System.currentTimeMillis();
+          RpcResponse rpcResponse =
+              transportClient.sendRpcSync(sendShuffleDataRequest, RPC_TIMEOUT_DEFAULT_MS);
+          LOG.debug("Do sendShuffleData to {}:{} rpc cost:{} ms for {} bytes with {} blocks",
+              host, nettyPort, System.currentTimeMillis() - start, allocateSize, finalBlockNum);
+          if (rpcResponse.getStatusCode() != StatusCode.SUCCESS) {
+            String msg = "Can't send shuffle data with " + finalBlockNum
+                + " blocks to " + host + ":" + nettyPort
+                + ", statusCode=" + rpcResponse.getStatusCode()
+                + ", errorMsg:" + rpcResponse.getRetMessage();
+            if (rpcResponse.getStatusCode() == StatusCode.NO_REGISTER) {
+              throw new NotRetryException(msg);
+            }
+            throw new RssException(msg);
+          }
+          return rpcResponse;
+        }, request.getRetryIntervalMax(), maxRetryAttempts);
+      } catch (Throwable throwable) {
+        // Pass the throwable so its stack trace is logged, not only its message.
+        LOG.warn("Failed to send shuffle data for shuffleId[{}] to {}:{}",
+            shuffleId, host, nettyPort, throwable);
+        isSuccessful = false;
+        break;
+      }
+    }
+
+    return new RssSendShuffleDataResponse(
+        isSuccessful ? StatusCode.SUCCESS : StatusCode.INTERNAL_ERROR);
+  }
+
+  /**
+   * Reads in-memory shuffle data of one partition over Netty.
+   *
+   * @param request identifies the partition, the last block id and the read buffer size
+   * @return the data together with its buffer segments
+   * @throws RssFetchFailedException if the server replies with a non-success status
+   */
+  @Override
+  public RssGetInMemoryShuffleDataResponse getInMemoryShuffleData(
+      RssGetInMemoryShuffleDataRequest request) {
+    TransportClient transportClient = getTransportClient();
+    GetMemoryShuffleDataRequest getMemoryShuffleDataRequest = new GetMemoryShuffleDataRequest(
+        requestId(),
+        request.getAppId(),
+        request.getShuffleId(),
+        request.getPartitionId(),
+        request.getLastBlockId(),
+        request.getReadBufferSize(),
+        System.currentTimeMillis(),
+        request.getExpectedTaskIds());
+    String requestInfo = "appId[" + request.getAppId()
+        + "], shuffleId[" + request.getShuffleId()
+        + "], partitionId[" + request.getPartitionId()
+        + "], lastBlockId[" + request.getLastBlockId() + "]";
+    RpcResponse rpcResponse =
+        transportClient.sendRpcSync(getMemoryShuffleDataRequest, RPC_TIMEOUT_DEFAULT_MS);
+    if (rpcResponse.getStatusCode() != StatusCode.SUCCESS) {
+      throw fetchFailed("shuffle in memory data", requestInfo, rpcResponse);
+    }
+    // Cast only after the status check, so an error reply that is not a
+    // GetMemoryShuffleDataResponse surfaces as a fetch failure, not a ClassCastException.
+    GetMemoryShuffleDataResponse getMemoryShuffleDataResponse =
+        (GetMemoryShuffleDataResponse) rpcResponse;
+    return new RssGetInMemoryShuffleDataResponse(
+        StatusCode.SUCCESS,
+        getMemoryShuffleDataResponse.getData().nioBuffer(),
+        getMemoryShuffleDataResponse.getBufferSegments());
+  }
+
+  /**
+   * Reads the shuffle index of one partition over Netty.
+   *
+   * @param request identifies the partition and its partition range settings
+   * @return the index data and the data file length reported by the server
+   * @throws RssFetchFailedException if the server replies with a non-success status
+   */
+  @Override
+  public RssGetShuffleIndexResponse getShuffleIndex(RssGetShuffleIndexRequest request) {
+    TransportClient transportClient = getTransportClient();
+    GetLocalShuffleIndexRequest getLocalShuffleIndexRequest = new GetLocalShuffleIndexRequest(
+        requestId(),
+        request.getAppId(),
+        request.getShuffleId(),
+        request.getPartitionId(),
+        request.getPartitionNumPerRange(),
+        request.getPartitionNum());
+    long start = System.currentTimeMillis();
+    RpcResponse rpcResponse =
+        transportClient.sendRpcSync(getLocalShuffleIndexRequest, RPC_TIMEOUT_DEFAULT_MS);
+    String requestInfo = "appId[" + request.getAppId()
+        + "], shuffleId[" + request.getShuffleId()
+        + "], partitionId[" + request.getPartitionId() + "]";
+    LOG.info("GetShuffleIndex from {}:{} for {} cost {} ms", host, nettyPort,
+        requestInfo, System.currentTimeMillis() - start);
+    if (rpcResponse.getStatusCode() != StatusCode.SUCCESS) {
+      throw fetchFailed("shuffle index", requestInfo, rpcResponse);
+    }
+    GetLocalShuffleIndexResponse getLocalShuffleIndexResponse =
+        (GetLocalShuffleIndexResponse) rpcResponse;
+    return new RssGetShuffleIndexResponse(
+        StatusCode.SUCCESS,
+        getLocalShuffleIndexResponse.getIndexData().nioBuffer(),
+        getLocalShuffleIndexResponse.getFileLength());
+  }
+
+  /**
+   * Reads {@code length} bytes at {@code offset} of one partition's shuffle data over Netty.
+   *
+   * @param request identifies the partition and the offset and length to read
+   * @return the bytes returned by the server
+   * @throws RssFetchFailedException if the server replies with a non-success status
+   */
+  @Override
+  public RssGetShuffleDataResponse getShuffleData(RssGetShuffleDataRequest request) {
+    TransportClient transportClient = getTransportClient();
+    GetLocalShuffleDataRequest getLocalShuffleDataRequest = new GetLocalShuffleDataRequest(
+        requestId(),
+        request.getAppId(),
+        request.getShuffleId(),
+        request.getPartitionId(),
+        request.getPartitionNumPerRange(),
+        request.getPartitionNum(),
+        request.getOffset(),
+        request.getLength(),
+        System.currentTimeMillis());
+    long start = System.currentTimeMillis();
+    RpcResponse rpcResponse =
+        transportClient.sendRpcSync(getLocalShuffleDataRequest, RPC_TIMEOUT_DEFAULT_MS);
+    String requestInfo = "appId[" + request.getAppId()
+        + "], shuffleId[" + request.getShuffleId()
+        + "], partitionId[" + request.getPartitionId() + "]";
+    LOG.info("GetShuffleData from {}:{} for {} cost {} ms", host, nettyPort,
+        requestInfo, System.currentTimeMillis() - start);
+    if (rpcResponse.getStatusCode() != StatusCode.SUCCESS) {
+      throw fetchFailed("shuffle data", requestInfo, rpcResponse);
+    }
+    GetLocalShuffleDataResponse getLocalShuffleDataResponse =
+        (GetLocalShuffleDataResponse) rpcResponse;
+    return new RssGetShuffleDataResponse(
+        StatusCode.SUCCESS, getLocalShuffleDataResponse.getData().nioBuffer());
+  }
+
+  /**
+   * Logs a non-success read reply and builds the exception to throw for it.
+   *
+   * @param what the kind of data that could not be read, used in the message
+   * @param requestInfo description of the failed request
+   * @param rpcResponse the non-success reply from the server
+   * @return the exception for the caller to throw
+   */
+  private RssFetchFailedException fetchFailed(
+      String what, String requestInfo, RpcResponse rpcResponse) {
+    String msg = "Can't get " + what + " from " + host + ":" + nettyPort
+        + " for " + requestInfo + ", errorMsg:" + rpcResponse.getRetMessage();
+    LOG.error(msg);
+    return new RssFetchFailedException(msg);
+  }
+
+  /**
+   * Obtains a transport client for {@code host:nettyPort} from the client factory.
+   *
+   * @return the transport client
+   * @throws RssException if the factory fails to create the client
+   */
+  private TransportClient getTransportClient() {
+    try {
+      return clientFactory.createClient(host, nettyPort);
+    } catch (Exception e) {
+      if (e instanceof InterruptedException) {
+        // Restore the interrupted status so callers up the stack can still observe it.
+        Thread.currentThread().interrupt();
+      }
+      throw new RssException(
+          "create transport client failed for " + host + ":" + nettyPort, e);
+    }
+  }
+}
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetInMemoryShuffleDataResponse.java
b/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetInMemoryShuffleDataResponse.java
index 70f218d5..1468d510 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetInMemoryShuffleDataResponse.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetInMemoryShuffleDataResponse.java
@@ -17,6 +17,7 @@
package org.apache.uniffle.client.response;
+import java.nio.ByteBuffer;
import java.util.List;
import org.apache.uniffle.common.BufferSegment;
@@ -24,17 +25,17 @@ import org.apache.uniffle.common.rpc.StatusCode;
public class RssGetInMemoryShuffleDataResponse extends ClientResponse {
- private final byte[] data;
+ private final ByteBuffer data;
private final List<BufferSegment> bufferSegments;
public RssGetInMemoryShuffleDataResponse(
- StatusCode statusCode, byte[] data, List<BufferSegment> bufferSegments) {
+ StatusCode statusCode, ByteBuffer data, List<BufferSegment>
bufferSegments) {
super(statusCode);
this.bufferSegments = bufferSegments;
this.data = data;
}
- public byte[] getData() {
+ public ByteBuffer getData() {
return data;
}
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleDataResponse.java
b/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleDataResponse.java
index 86c54f45..33e83a34 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleDataResponse.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleDataResponse.java
@@ -17,18 +17,20 @@
package org.apache.uniffle.client.response;
+import java.nio.ByteBuffer;
+
import org.apache.uniffle.common.rpc.StatusCode;
public class RssGetShuffleDataResponse extends ClientResponse {
- private final byte[] shuffleData;
+ private final ByteBuffer shuffleData;
- public RssGetShuffleDataResponse(StatusCode statusCode, byte[] data) {
+ public RssGetShuffleDataResponse(StatusCode statusCode, ByteBuffer data) {
super(statusCode);
this.shuffleData = data;
}
- public byte[] getShuffleData() {
+ public ByteBuffer getShuffleData() {
return shuffleData;
}
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleIndexResponse.java
b/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleIndexResponse.java
index f1c3171f..ee1708cc 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleIndexResponse.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleIndexResponse.java
@@ -17,13 +17,15 @@
package org.apache.uniffle.client.response;
+import java.nio.ByteBuffer;
+
import org.apache.uniffle.common.ShuffleIndexResult;
import org.apache.uniffle.common.rpc.StatusCode;
public class RssGetShuffleIndexResponse extends ClientResponse {
private final ShuffleIndexResult shuffleIndexResult;
- public RssGetShuffleIndexResponse(StatusCode statusCode, byte[] data, long
dataFileLen) {
+ public RssGetShuffleIndexResponse(StatusCode statusCode, ByteBuffer data,
long dataFileLen) {
super(statusCode);
this.shuffleIndexResult = new ShuffleIndexResult(data, dataFileLen);
}
diff --git
a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
index b14873d1..4bd86714 100644
---
a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
+++
b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
@@ -17,6 +17,7 @@
package org.apache.uniffle.server;
+import java.nio.ByteBuffer;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -603,14 +604,14 @@ public class ShuffleServerGrpcService extends
ShuffleServerImplBase {
appId, shuffleId, partitionId, partitionNumPerRange, partitionNum);
long readTime = System.currentTimeMillis() - start;
- byte[] data = shuffleIndexResult.getIndexData();
- ShuffleServerMetrics.counterTotalReadDataSize.inc(data.length);
-
ShuffleServerMetrics.counterTotalReadLocalIndexFileSize.inc(data.length);
+ ByteBuffer data = shuffleIndexResult.getIndexData();
+ ShuffleServerMetrics.counterTotalReadDataSize.inc(data.remaining());
+
ShuffleServerMetrics.counterTotalReadLocalIndexFileSize.inc(data.remaining());
GetLocalShuffleIndexResponse.Builder builder =
GetLocalShuffleIndexResponse.newBuilder()
.setStatus(status.toProto())
.setRetMsg(msg);
LOG.info("Successfully getShuffleIndex cost {} ms for {}"
- + " bytes with {}", readTime, data.length, requestInfo);
+ + " bytes with {}", readTime, data.remaining(), requestInfo);
builder.setIndexData(UnsafeByteOperations.unsafeWrap(data));
builder.setDataFileLen(shuffleIndexResult.getDataFileLen());
diff --git
a/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
b/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
index efcc8fd6..2039cdc5 100644
---
a/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
+++
b/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
@@ -58,7 +58,6 @@ public class ShuffleHandlerFactory {
return INSTANCE;
}
-
public ClientReadHandler
createShuffleReadHandler(CreateShuffleReadHandlerRequest request) {
if (CollectionUtils.isEmpty(request.getShuffleServerInfoList())) {
throw new RssException("Shuffle servers should not be empty!");
@@ -118,7 +117,7 @@ public class ShuffleHandlerFactory {
private ClientReadHandler
getMemoryClientReadHandler(CreateShuffleReadHandlerRequest request,
ShuffleServerInfo ssi) {
ShuffleServerClient shuffleServerClient =
ShuffleServerClientFactory.getInstance().getShuffleServerClient(
- ClientType.GRPC.name(), ssi);
+ ClientType.GRPC.name(), ssi, request.getClientConf());
Roaring64NavigableMap expectTaskIds = null;
if (request.isExpectedTaskIdsBitmapFilterEnable()) {
Roaring64NavigableMap realExceptBlockIds =
RssUtils.cloneBitMap(request.getExpectBlockIds());
@@ -139,7 +138,7 @@ public class ShuffleHandlerFactory {
private ClientReadHandler
getLocalfileClientReaderHandler(CreateShuffleReadHandlerRequest request,
ShuffleServerInfo
ssi) {
ShuffleServerClient shuffleServerClient =
ShuffleServerClientFactory.getInstance().getShuffleServerClient(
- ClientType.GRPC.name(), ssi);
+ ClientType.GRPC.name(), ssi, request.getClientConf());
return new LocalFileClientReadHandler(
request.getAppId(), request.getShuffleId(), request.getPartitionId(),
request.getIndexReadLimit(), request.getPartitionNumPerRange(),
request.getPartitionNum(),
diff --git
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HdfsShuffleReadHandler.java
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HdfsShuffleReadHandler.java
index 0be589b6..8e84f963 100644
---
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HdfsShuffleReadHandler.java
+++
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HdfsShuffleReadHandler.java
@@ -95,7 +95,7 @@ public class HdfsShuffleReadHandler extends
DataSkippableReadHandler {
}
long dateFileLen = getDataFileLen();
LOG.info("Read index files {}.index for {} ms", filePrefix,
System.currentTimeMillis() - start);
- return new ShuffleIndexResult(indexData, dateFileLen);
+ return new ShuffleIndexResult(ByteBuffer.wrap(indexData), dateFileLen);
} catch (Exception e) {
LOG.info("Fail to read index files {}.index", filePrefix, e);
}
diff --git
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandler.java
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandler.java
index 5df89b9f..66c5f954 100644
---
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandler.java
+++
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandler.java
@@ -19,6 +19,7 @@ package org.apache.uniffle.storage.handler.impl;
import java.io.File;
import java.io.FilenameFilter;
+import java.nio.ByteBuffer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -150,7 +151,7 @@ public class LocalFileServerReadHandler implements
ServerReadHandler {
byte[] indexData = reader.read(0, len);
// get dataFileSize for read segment generation in
DataSkippableReadHandler#readShuffleData
long dataFileSize = new File(dataFileName).length();
- return new ShuffleIndexResult(indexData, dataFileSize);
+ return new ShuffleIndexResult(ByteBuffer.wrap(indexData), dataFileSize);
} catch (Exception e) {
LOG.error("Fail to read index file {} indexNum {} len {}",
indexFileName, indexNum, len);
diff --git
a/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
b/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
index a3aaab72..7ff8398e 100644
---
a/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
+++
b/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
@@ -25,6 +25,7 @@ import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssBaseConf;
+import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.util.IdHelper;
public class CreateShuffleReadHandlerRequest {
@@ -47,6 +48,7 @@ public class CreateShuffleReadHandlerRequest {
private Roaring64NavigableMap expectTaskIds;
private boolean expectedTaskIdsBitmapFilterEnable;
private boolean offHeapEnabled;
+ private RssConf clientConf;
private IdHelper idHelper;
@@ -204,4 +206,12 @@ public class CreateShuffleReadHandlerRequest {
public boolean isOffHeapEnabled() {
return offHeapEnabled;
}
+
+  public RssConf getClientConf() {
+    return this.clientConf;
+  }
+
+  public void setClientConf(RssConf conf) {
+    this.clientConf = conf;
+  }
}
diff --git
a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HdfsClientReadHandlerTest.java
b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HdfsClientReadHandlerTest.java
index 2662a61f..1688cf65 100644
---
a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HdfsClientReadHandlerTest.java
+++
b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HdfsClientReadHandlerTest.java
@@ -92,7 +92,7 @@ public class HdfsClientReadHandlerTest extends HdfsTestBase {
readBufferSize, expectBlockIds, processBlockIds, hadoopConf);
try {
ShuffleIndexResult indexResult = indexReader.readShuffleIndex();
- assertEquals(0, indexResult.getIndexData().length %
FileBasedShuffleSegment.SEGMENT_SIZE);
+ assertEquals(0, indexResult.getIndexData().remaining() %
FileBasedShuffleSegment.SEGMENT_SIZE);
} catch (Exception e) {
fail();
}
diff --git
a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandlerTest.java
b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandlerTest.java
index 5f588154..025cc124 100644
---
a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandlerTest.java
+++
b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandlerTest.java
@@ -63,6 +63,7 @@ public class LocalFileServerReadHandlerTest {
LocalFileHandlerTestBase.writeIndex(byteBuffer, segment);
}
}, expectedData, new HashSet<>());
+ byteBuffer.rewind();
blocks.forEach(block -> expectBlockIds.addLong(block.getBlockId()));
@@ -74,7 +75,7 @@ public class LocalFileServerReadHandlerTest {
int actualWriteDataBlock = expectTotalBlockNum - 1;
int actualFileLen = blockSize * actualWriteDataBlock;
RssGetShuffleIndexResponse response = new
RssGetShuffleIndexResponse(StatusCode.SUCCESS,
- byteBuffer.array(), actualFileLen);
+ byteBuffer, actualFileLen);
Mockito.doReturn(response).when(mockShuffleServerClient).getShuffleIndex(Mockito.any());
int readBufferSize = 13;
@@ -91,9 +92,9 @@ public class LocalFileServerReadHandlerTest {
ArgumentMatcher<RssGetShuffleDataRequest> segment2Match =
(request) -> request.getOffset() == bytesPerSegment &&
request.getLength() == blockSize;
RssGetShuffleDataResponse segment1Response =
- new RssGetShuffleDataResponse(StatusCode.SUCCESS, segments.get(0));
+ new RssGetShuffleDataResponse(StatusCode.SUCCESS,
ByteBuffer.wrap(segments.get(0)));
RssGetShuffleDataResponse segment2Response =
- new RssGetShuffleDataResponse(StatusCode.SUCCESS, segments.get(1));
+ new RssGetShuffleDataResponse(StatusCode.SUCCESS,
ByteBuffer.wrap(segments.get(1)));
Mockito.doReturn(segment1Response).when(mockShuffleServerClient).getShuffleData(Mockito.argThat(segment1Match));
Mockito.doReturn(segment2Response).when(mockShuffleServerClient).getShuffleData(Mockito.argThat(segment2Match));