This is an automated email from the ASF dual-hosted git repository.
jerrylei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new e9a62eeaf [#825][part-3] feat(spark): Get the ShuffleServer
corresponding to the partition from ShuffleManager. (#1141)
e9a62eeaf is described below
commit e9a62eeafbced0eb38f95f35b7fb885bf0c47e99
Author: yl09099 <[email protected]>
AuthorDate: Mon Oct 30 19:08:35 2023 +0800
[#825][part-3] feat(spark): Get the ShuffleServer corresponding to the
partition from ShuffleManager. (#1141)
### What changes were proposed in this pull request?
ShuffleReader and ShuffleWriter obtain the ShuffleServer corresponding to the
partition from the ShuffleManager.
### Why are the changes needed?
Fix: #825
### Does this PR introduce any user-facing change?
No.
### How was this patch tested?
Existing UT.
---
.../manager/RssShuffleManagerInterface.java | 9 ++
.../shuffle/manager/ShuffleManagerGrpcService.java | 52 +++++++++++
.../shuffle/manager/DummyRssShuffleManager.java | 7 ++
.../apache/spark/shuffle/RssShuffleManager.java | 98 ++++++++++++++++++--
.../spark/shuffle/reader/RssShuffleReader.java | 7 +-
.../spark/shuffle/writer/RssShuffleWriter.java | 26 ++++--
.../spark/shuffle/reader/RssShuffleReaderTest.java | 3 +-
.../spark/shuffle/writer/RssShuffleWriterTest.java | 20 +++-
.../apache/spark/shuffle/RssShuffleManager.java | 101 +++++++++++++++++++--
.../spark/shuffle/reader/RssShuffleReader.java | 5 +-
.../spark/shuffle/writer/RssShuffleWriter.java | 28 ++++--
.../spark/shuffle/reader/RssShuffleReaderTest.java | 9 +-
.../spark/shuffle/writer/RssShuffleWriterTest.java | 26 +++++-
.../apache/uniffle/common/ShuffleServerInfo.java | 37 ++++++++
.../apache/uniffle/test/RSSStageResubmitTest.java | 3 +
.../uniffle/client/api/ShuffleManagerClient.java | 11 +++
.../client/impl/grpc/ShuffleManagerGrpcClient.java | 14 +++
.../RssPartitionToShuffleServerRequest.java} | 25 +++--
.../RssPartitionToShuffleServerResponse.java | 101 +++++++++++++++++++++
proto/src/main/proto/Rss.proto | 18 ++++
20 files changed, 544 insertions(+), 56 deletions(-)
diff --git
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java
index 4308602dc..34009cbad 100644
---
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java
+++
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java
@@ -18,6 +18,7 @@
package org.apache.uniffle.shuffle.manager;
import org.apache.spark.SparkException;
+import org.apache.spark.shuffle.ShuffleHandleInfo;
/**
* This is a proxy interface that mainly delegates the un-registration of
shuffles to the
@@ -54,4 +55,12 @@ public interface RssShuffleManagerInterface {
* @throws SparkException
*/
void unregisterAllMapOutput(int shuffleId) throws SparkException;
+
+ /**
+ * Get ShuffleHandleInfo with ShuffleId
+ *
+ * @param shuffleId
+ * @return ShuffleHandleInfo
+ */
+ ShuffleHandleInfo getShuffleHandleInfoByShuffleId(int shuffleId);
}
diff --git
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
index 6dfd52e24..5c4b8795c 100644
---
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
+++
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
@@ -18,14 +18,18 @@
package org.apache.uniffle.shuffle.manager;
import java.util.Arrays;
+import java.util.List;
import java.util.Map;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.Supplier;
import io.grpc.stub.StreamObserver;
+import org.apache.spark.shuffle.ShuffleHandleInfo;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.apache.uniffle.common.RemoteStorageInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.proto.RssProtos;
import org.apache.uniffle.proto.ShuffleManagerGrpc.ShuffleManagerImplBase;
@@ -101,6 +105,54 @@ public class ShuffleManagerGrpcService extends
ShuffleManagerImplBase {
responseObserver.onCompleted();
}
+ @Override
+ public void getPartitionToShufflerServer(
+ RssProtos.PartitionToShuffleServerRequest request,
+ StreamObserver<RssProtos.PartitionToShuffleServerResponse>
responseObserver) {
+ RssProtos.PartitionToShuffleServerResponse reply;
+ RssProtos.StatusCode code;
+ int shuffleId = request.getShuffleId();
+ ShuffleHandleInfo shuffleHandleInfoByShuffleId =
+ shuffleManager.getShuffleHandleInfoByShuffleId(shuffleId);
+ if (shuffleHandleInfoByShuffleId != null) {
+ code = RssProtos.StatusCode.SUCCESS;
+ Map<Integer, List<ShuffleServerInfo>> partitionToServers =
+ shuffleHandleInfoByShuffleId.getPartitionToServers();
+ Map<Integer, RssProtos.GetShuffleServerListResponse>
protopartitionToServers =
+ JavaUtils.newConcurrentMap();
+ for (Map.Entry<Integer, List<ShuffleServerInfo>> integerListEntry :
+ partitionToServers.entrySet()) {
+ List<RssProtos.ShuffleServerId> shuffleServerIds =
+ ShuffleServerInfo.toProto(integerListEntry.getValue());
+ RssProtos.GetShuffleServerListResponse getShuffleServerListResponse =
+ RssProtos.GetShuffleServerListResponse.newBuilder()
+ .addAllServers(shuffleServerIds)
+ .build();
+ protopartitionToServers.put(integerListEntry.getKey(),
getShuffleServerListResponse);
+ }
+ RemoteStorageInfo remoteStorage =
shuffleHandleInfoByShuffleId.getRemoteStorage();
+ RssProtos.RemoteStorageInfo.Builder protosRemoteStage =
+ RssProtos.RemoteStorageInfo.newBuilder()
+ .setPath(remoteStorage.getPath())
+ .putAllConfItems(remoteStorage.getConfItems());
+ reply =
+ RssProtos.PartitionToShuffleServerResponse.newBuilder()
+ .setStatus(code)
+ .putAllPartitionToShuffleServer(protopartitionToServers)
+ .setRemoteStorageInfo(protosRemoteStage)
+ .build();
+ } else {
+ code = RssProtos.StatusCode.INVALID_REQUEST;
+ reply =
+ RssProtos.PartitionToShuffleServerResponse.newBuilder()
+ .setStatus(code)
+ .putAllPartitionToShuffleServer(null)
+ .build();
+ }
+ responseObserver.onNext(reply);
+ responseObserver.onCompleted();
+ }
+
/**
* Remove the no longer used shuffle id's rss shuffle status. This is called
when ShuffleManager
* unregisters the corresponding shuffle id.
diff --git
a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java
index 9e06da4e6..dfd4b69a1 100644
---
a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java
+++
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java
@@ -20,6 +20,8 @@ package org.apache.uniffle.shuffle.manager;
import java.util.LinkedHashSet;
import java.util.Set;
+import org.apache.spark.shuffle.ShuffleHandleInfo;
+
public class DummyRssShuffleManager implements RssShuffleManagerInterface {
public Set<Integer> unregisteredShuffleIds = new LinkedHashSet<>();
@@ -47,4 +49,9 @@ public class DummyRssShuffleManager implements
RssShuffleManagerInterface {
public void unregisterAllMapOutput(int shuffleId) {
unregisteredShuffleIds.add(shuffleId);
}
+
+ @Override
+ public ShuffleHandleInfo getShuffleHandleInfoByShuffleId(int shuffleId) {
+ return null;
+ }
}
diff --git
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 4efde2ec5..dcebe11f8 100644
---
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -33,7 +33,6 @@ import scala.collection.Iterator;
import scala.collection.Seq;
import com.google.common.annotations.VisibleForTesting;
-import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import org.apache.hadoop.conf.Configuration;
import org.apache.spark.ShuffleDependency;
@@ -52,14 +51,21 @@ import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.factory.ShuffleClientFactory;
+import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
+import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest;
+import org.apache.uniffle.client.response.RssPartitionToShuffleServerResponse;
import org.apache.uniffle.client.util.ClientUtils;
+import org.apache.uniffle.client.util.RssClientConfig;
+import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.PartitionRange;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleAssignmentsInfo;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.exception.RssFetchFailedException;
@@ -104,10 +110,19 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
private DataPusher dataPusher;
private final int maxConcurrencyPerPartitionToWrite;
- private final Map<Integer, Integer> shuffleIdToPartitionNum =
Maps.newConcurrentMap();
- private final Map<Integer, Integer> shuffleIdToNumMapTasks =
Maps.newConcurrentMap();
+ private final Map<Integer, Integer> shuffleIdToPartitionNum =
JavaUtils.newConcurrentMap();
+ private final Map<Integer, Integer> shuffleIdToNumMapTasks =
JavaUtils.newConcurrentMap();
private GrpcServer shuffleManagerServer;
private ShuffleManagerGrpcService service;
+ private ShuffleManagerClient shuffleManagerClient;
+ /**
+ * Mapping between ShuffleId and ShuffleServer list. ShuffleServer list is
dynamically allocated.
+ * ShuffleServer is not obtained from RssShuffleHandle, but from this
mapping.
+ */
+ private Map<Integer, ShuffleHandleInfo> shuffleIdToShuffleHandleInfo =
+ JavaUtils.newConcurrentMap();
+ /** Whether to enable the dynamic shuffleServer function rewrite and reread
functions */
+ private boolean rssResubmitStage;
public RssShuffleManager(SparkConf sparkConf, boolean isDriver) {
if (sparkConf.getBoolean("spark.sql.adaptive.enabled", false)) {
@@ -183,12 +198,14 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
// shuffle cluster, we don't need shuffle data locality
sparkConf.set("spark.shuffle.reduceLocality.enabled", "false");
LOG.info("Disable shuffle data locality in RssShuffleManager.");
+ this.rssResubmitStage =
+ rssConf.getBoolean(RssClientConfig.RSS_RESUBMIT_STAGE, false)
+ && RssSparkShuffleUtils.isStageResubmitSupported();
if (!sparkConf.getBoolean(RssSparkConfig.RSS_TEST_FLAG.key(), false)) {
if (isDriver) {
heartBeatScheduledExecutorService =
ThreadUtils.getDaemonSingleThreadScheduledExecutor("rss-heartbeat");
- if (sparkConf.get(RssSparkConfig.RSS_RESUBMIT_STAGE)
- && RssSparkShuffleUtils.isStageResubmitSupported()) {
+ if (rssResubmitStage) {
LOG.info("stage resubmit is supported and enabled");
// start shuffle manager server
rssConf.set(RPC_SERVER_PORT, 0);
@@ -330,6 +347,11 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
shuffleIdToPartitionNum.putIfAbsent(shuffleId,
dependency.partitioner().numPartitions());
shuffleIdToNumMapTasks.putIfAbsent(shuffleId,
dependency.rdd().partitions().length);
+ if (rssResubmitStage) {
+ ShuffleHandleInfo handleInfo =
+ new ShuffleHandleInfo(shuffleId, partitionToServers, remoteStorage);
+ shuffleIdToShuffleHandleInfo.put(shuffleId, handleInfo);
+ }
Broadcast<ShuffleHandleInfo> hdlInfoBd =
RssSparkShuffleUtils.broadcastShuffleHdlInfo(
RssSparkShuffleUtils.getActiveSparkContext(),
@@ -421,6 +443,15 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
int shuffleId = rssHandle.getShuffleId();
String taskId = "" + context.taskAttemptId() + "_" +
context.attemptNumber();
+ ShuffleHandleInfo shuffleHandleInfo;
+ if (rssResubmitStage) {
+ // Get the ShuffleServer list from the Driver based on the shuffleId
+ shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
+ } else {
+ shuffleHandleInfo =
+ new ShuffleHandleInfo(
+ shuffleId, rssHandle.getPartitionToServers(),
rssHandle.getRemoteStorage());
+ }
ShuffleWriteMetrics writeMetrics =
context.taskMetrics().shuffleWriteMetrics();
return new RssShuffleWriter<>(
rssHandle.getAppId(),
@@ -433,7 +464,8 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
shuffleWriteClient,
rssHandle,
this::markFailedTask,
- context);
+ context,
+ shuffleHandleInfo);
} else {
throw new RssException("Unexpected ShuffleHandle:" +
handle.getClass().getName());
}
@@ -463,8 +495,19 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
+ startPartition
+ "]");
start = System.currentTimeMillis();
+ ShuffleHandleInfo shuffleHandleInfo;
+ if (rssResubmitStage) {
+ // Get the ShuffleServer list from the Driver based on the shuffleId
+ shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
+ } else {
+ shuffleHandleInfo =
+ new ShuffleHandleInfo(
+ shuffleId,
+ rssShuffleHandle.getPartitionToServers(),
+ rssShuffleHandle.getRemoteStorage());
+ }
Map<Integer, List<ShuffleServerInfo>> partitionToServers =
- rssShuffleHandle.getPartitionToServers();
+ shuffleHandleInfo.getPartitionToServers();
Roaring64NavigableMap blockIdBitmap =
getShuffleResult(
clientType,
@@ -501,7 +544,8 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
partitionNum,
blockIdBitmap,
taskIdBitmap,
- RssSparkConfig.toRssConf(sparkConf));
+ RssSparkConfig.toRssConf(sparkConf),
+ partitionToServers);
} else {
throw new RssException("Unexpected ShuffleHandle:" +
handle.getClass().getName());
}
@@ -712,4 +756,42 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
}
return result;
}
+
+ @Override
+ public ShuffleHandleInfo getShuffleHandleInfoByShuffleId(int shuffleId) {
+ return shuffleIdToShuffleHandleInfo.get(shuffleId);
+ }
+
+ private ShuffleManagerClient createShuffleManagerClient(String host, int
port) {
+ // Host can be inferred from `spark.driver.bindAddress`, which would be
set when SparkContext is
+ // constructed.
+ return ShuffleManagerClientFactory.getInstance()
+ .createShuffleManagerClient(ClientType.GRPC, host, port);
+ }
+
+ /**
+ * Get the ShuffleServer list from the Driver based on the shuffleId
+ *
+ * @param shuffleId shuffleId
+ * @return ShuffleHandleInfo
+ */
+ private ShuffleHandleInfo getRemoteShuffleHandleInfo(int shuffleId) {
+ ShuffleHandleInfo shuffleHandleInfo;
+ RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
+ String driver = rssConf.getString("driver.host", "");
+ int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
+ if (shuffleManagerClient == null) {
+ shuffleManagerClient = createShuffleManagerClient(driver, port);
+ }
+ RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest =
+ new RssPartitionToShuffleServerRequest(shuffleId);
+ RssPartitionToShuffleServerResponse rpcPartitionToShufflerServer =
+
shuffleManagerClient.getPartitionToShufflerServer(rssPartitionToShuffleServerRequest);
+ shuffleHandleInfo =
+ new ShuffleHandleInfo(
+ shuffleId,
+ rpcPartitionToShufflerServer.getPartitionToServers(),
+ rpcPartitionToShufflerServer.getRemoteStorageInfo());
+ return shuffleHandleInfo;
+ }
}
diff --git
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index 9130bc587..75855ba7a 100644
---
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -18,6 +18,7 @@
package org.apache.spark.shuffle.reader;
import java.util.List;
+import java.util.Map;
import scala.Function0;
import scala.Option;
@@ -83,7 +84,8 @@ public class RssShuffleReader<K, C> implements
ShuffleReader<K, C> {
int partitionNum,
Roaring64NavigableMap blockIdBitmap,
Roaring64NavigableMap taskIdBitmap,
- RssConf rssConf) {
+ RssConf rssConf,
+ Map<Integer, List<ShuffleServerInfo>> partitionToServers) {
this.appId = rssShuffleHandle.getAppId();
this.startPartition = startPartition;
this.endPartition = endPartition;
@@ -98,8 +100,7 @@ public class RssShuffleReader<K, C> implements
ShuffleReader<K, C> {
this.blockIdBitmap = blockIdBitmap;
this.taskIdBitmap = taskIdBitmap;
this.hadoopConf = hadoopConf;
- this.shuffleServerInfoList =
- (List<ShuffleServerInfo>)
(rssShuffleHandle.getPartitionToServers().get(startPartition));
+ this.shuffleServerInfoList = (List<ShuffleServerInfo>)
(partitionToServers.get(startPartition));
this.rssConf = rssConf;
expectedTaskIdsBitmapFilterEnable = shuffleServerInfoList.size() > 1;
}
diff --git
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 11f2dd3ba..fabac5651 100644
---
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -50,6 +50,7 @@ import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.shuffle.RssShuffleHandle;
import org.apache.spark.shuffle.RssShuffleManager;
import org.apache.spark.shuffle.RssSparkConfig;
+import org.apache.spark.shuffle.ShuffleHandleInfo;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.storage.BlockManagerId;
import org.slf4j.Logger;
@@ -89,6 +90,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
private boolean isMemoryShuffleEnabled;
private final Function<String, Boolean> taskFailureCallback;
private final Set<Long> blockIds = Sets.newConcurrentHashSet();
+ private TaskContext taskContext;
public RssShuffleWriter(
String appId,
@@ -100,7 +102,9 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
RssShuffleManager shuffleManager,
SparkConf sparkConf,
ShuffleWriteClient shuffleWriteClient,
- RssShuffleHandle<K, V, C> rssHandle) {
+ RssShuffleHandle<K, V, C> rssHandle,
+ ShuffleHandleInfo shuffleHandleInfo,
+ TaskContext context) {
this(
appId,
shuffleId,
@@ -111,7 +115,9 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
sparkConf,
shuffleWriteClient,
rssHandle,
- (tid) -> true);
+ (tid) -> true,
+ shuffleHandleInfo,
+ context);
this.bufferManager = bufferManager;
}
@@ -125,7 +131,9 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
SparkConf sparkConf,
ShuffleWriteClient shuffleWriteClient,
RssShuffleHandle<K, V, C> rssHandle,
- Function<String, Boolean> taskFailureCallback) {
+ Function<String, Boolean> taskFailureCallback,
+ ShuffleHandleInfo shuffleHandleInfo,
+ TaskContext context) {
this.appId = appId;
this.shuffleId = shuffleId;
this.taskId = taskId;
@@ -141,11 +149,12 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
this.bitmapSplitNum =
sparkConf.get(RssSparkConfig.RSS_CLIENT_BITMAP_SPLIT_NUM);
this.partitionToBlockIds = Maps.newHashMap();
this.shuffleWriteClient = shuffleWriteClient;
- this.shuffleServersForData = rssHandle.getShuffleServersForData();
- this.partitionToServers = rssHandle.getPartitionToServers();
+ this.shuffleServersForData = shuffleHandleInfo.getShuffleServersForData();
+ this.partitionToServers = shuffleHandleInfo.getPartitionToServers();
this.isMemoryShuffleEnabled =
isMemoryShuffleEnabled(sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()));
this.taskFailureCallback = taskFailureCallback;
+ this.taskContext = context;
}
public RssShuffleWriter(
@@ -159,7 +168,8 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
ShuffleWriteClient shuffleWriteClient,
RssShuffleHandle<K, V, C> rssHandle,
Function<String, Boolean> taskFailureCallback,
- TaskContext context) {
+ TaskContext context,
+ ShuffleHandleInfo shuffleHandleInfo) {
this(
appId,
shuffleId,
@@ -170,7 +180,9 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
sparkConf,
shuffleWriteClient,
rssHandle,
- taskFailureCallback);
+ taskFailureCallback,
+ shuffleHandleInfo,
+ context);
BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
final WriteBufferManager bufferManager =
new WriteBufferManager(
diff --git
a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
index a6c25516a..f09223b1c 100644
---
a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
+++
b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
@@ -98,7 +98,8 @@ public class RssShuffleReaderTest extends
AbstractRssReaderTest {
10,
blockIdBitmap,
taskIdBitmap,
- rssConf));
+ rssConf,
+ partitionToServers));
validateResult(rssShuffleReaderSpy.read(), expectedData, 10);
}
diff --git
a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index e60930d8c..13fb93f7a 100644
---
a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++
b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -37,6 +37,7 @@ import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
+import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.memory.TaskMemoryManager;
@@ -45,6 +46,7 @@ import org.apache.spark.serializer.Serializer;
import org.apache.spark.shuffle.RssShuffleHandle;
import org.apache.spark.shuffle.RssShuffleManager;
import org.apache.spark.shuffle.RssSparkConfig;
+import org.apache.spark.shuffle.ShuffleHandleInfo;
import org.junit.jupiter.api.Test;
import org.apache.uniffle.client.api.ShuffleWriteClient;
@@ -93,6 +95,8 @@ public class RssShuffleWriterTest {
when(mockPartitioner.numPartitions()).thenReturn(2);
when(mockHandle.getPartitionToServers()).thenReturn(Maps.newHashMap());
TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class);
+ TaskContext contextMock = mock(TaskContext.class);
+ ShuffleHandleInfo mockShuffleHandleInfo = mock(ShuffleHandleInfo.class);
BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
WriteBufferManager bufferManager =
@@ -119,7 +123,9 @@ public class RssShuffleWriterTest {
manager,
conf,
mockShuffleWriteClient,
- mockHandle);
+ mockHandle,
+ mockShuffleHandleInfo,
+ contextMock);
// case 1: all blocks are sent successfully
manager.addSuccessBlockIds(taskId, Sets.newHashSet(1L, 2L, 3L));
@@ -274,6 +280,8 @@ public class RssShuffleWriterTest {
null);
WriteBufferManager bufferManagerSpy = spy(bufferManager);
doReturn(1000000L).when(bufferManagerSpy).acquireMemory(anyLong());
+ TaskContext contextMock = mock(TaskContext.class);
+ ShuffleHandleInfo mockShuffleHandleInfo = mock(ShuffleHandleInfo.class);
RssShuffleWriter<String, String, String> rssShuffleWriter =
new RssShuffleWriter<>(
@@ -286,7 +294,9 @@ public class RssShuffleWriterTest {
manager,
conf,
mockShuffleWriteClient,
- mockHandle);
+ mockHandle,
+ mockShuffleHandleInfo,
+ contextMock);
RssShuffleWriter<String, String, String> rssShuffleWriterSpy =
spy(rssShuffleWriter);
doNothing().when(rssShuffleWriterSpy).sendCommit();
@@ -382,6 +392,8 @@ public class RssShuffleWriterTest {
RssShuffleHandle<String, String, String> mockHandle =
mock(RssShuffleHandle.class);
when(mockHandle.getDependency()).thenReturn(mockDependency);
ShuffleWriteClient mockWriteClient = mock(ShuffleWriteClient.class);
+ TaskContext contextMock = mock(TaskContext.class);
+ ShuffleHandleInfo mockShuffleHandleInfo = mock(ShuffleHandleInfo.class);
RssShuffleWriter<String, String, String> writer =
new RssShuffleWriter<>(
@@ -394,7 +406,9 @@ public class RssShuffleWriterTest {
manager,
conf,
mockWriteClient,
- mockHandle);
+ mockHandle,
+ mockShuffleHandleInfo,
+ contextMock);
List<ShuffleBlockInfo> shuffleBlockInfoList = createShuffleBlockList(1,
31);
writer.postBlockEvent(shuffleBlockInfoList);
Thread.sleep(500);
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 9ec5e90ae..d9e730e7e 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -36,7 +36,6 @@ import scala.collection.Iterator;
import scala.collection.Seq;
import com.google.common.annotations.VisibleForTesting;
-import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import org.apache.hadoop.conf.Configuration;
@@ -59,9 +58,15 @@ import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.factory.ShuffleClientFactory;
+import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
+import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest;
+import org.apache.uniffle.client.response.RssPartitionToShuffleServerResponse;
import org.apache.uniffle.client.util.ClientUtils;
+import org.apache.uniffle.client.util.RssClientConfig;
+import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.PartitionRange;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleAssignmentsInfo;
@@ -111,8 +116,8 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
private Set<String> failedTaskIds = Sets.newConcurrentHashSet();
private DataPusher dataPusher;
- private final Map<Integer, Integer> shuffleIdToPartitionNum =
Maps.newConcurrentMap();
- private final Map<Integer, Integer> shuffleIdToNumMapTasks =
Maps.newConcurrentMap();
+ private final Map<Integer, Integer> shuffleIdToPartitionNum =
JavaUtils.newConcurrentMap();
+ private final Map<Integer, Integer> shuffleIdToNumMapTasks =
JavaUtils.newConcurrentMap();
private ShuffleManagerGrpcService service;
private GrpcServer shuffleManagerServer;
@@ -121,6 +126,15 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
protected ShuffleWriteClient shuffleWriteClient;
+ private ShuffleManagerClient shuffleManagerClient;
+ /**
+ * Mapping between ShuffleId and ShuffleServer list. ShuffleServer list is
dynamically allocated.
+ * ShuffleServer is not obtained from RssShuffleHandle, but from this
mapping.
+ */
+ private Map<Integer, ShuffleHandleInfo> shuffleIdToShuffleHandleInfo;
+ /** Whether to enable the dynamic shuffleServer function rewrite and reread
functions */
+ private boolean rssResubmitStage;
+
public RssShuffleManager(SparkConf conf, boolean isDriver) {
this.sparkConf = conf;
boolean supportsRelocation =
@@ -209,11 +223,13 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
taskToSuccessBlockIds = JavaUtils.newConcurrentMap();
taskToFailedBlockIds = JavaUtils.newConcurrentMap();
this.taskToFailedBlockIdsAndServer = JavaUtils.newConcurrentMap();
+ this.rssResubmitStage =
+ rssConf.getBoolean(RssClientConfig.RSS_RESUBMIT_STAGE, false)
+ && RssSparkShuffleUtils.isStageResubmitSupported();
if (isDriver) {
heartBeatScheduledExecutorService =
ThreadUtils.getDaemonSingleThreadScheduledExecutor("rss-heartbeat");
- if (sparkConf.get(RssSparkConfig.RSS_RESUBMIT_STAGE)
- && RssSparkShuffleUtils.isStageResubmitSupported()) {
+ if (rssResubmitStage) {
LOG.info("stage resubmit is supported and enabled");
// start shuffle manager server
rssConf.set(RPC_SERVER_PORT, 0);
@@ -244,6 +260,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
failedTaskIds,
poolSize,
keepAliveTime);
+ this.shuffleIdToShuffleHandleInfo = JavaUtils.newConcurrentMap();
}
public CompletableFuture<Long> sendData(AddBlockEvent event) {
@@ -428,6 +445,11 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
shuffleIdToPartitionNum.putIfAbsent(shuffleId,
dependency.partitioner().numPartitions());
shuffleIdToNumMapTasks.putIfAbsent(shuffleId,
dependency.rdd().partitions().length);
+ if (rssResubmitStage) {
+ ShuffleHandleInfo handleInfo =
+ new ShuffleHandleInfo(shuffleId, partitionToServers, remoteStorage);
+ shuffleIdToShuffleHandleInfo.put(shuffleId, handleInfo);
+ }
Broadcast<ShuffleHandleInfo> hdlInfoBd =
RssSparkShuffleUtils.broadcastShuffleHdlInfo(
RssSparkShuffleUtils.getActiveSparkContext(),
@@ -454,14 +476,22 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
RssShuffleHandle<K, V, ?> rssHandle = (RssShuffleHandle<K, V, ?>) handle;
setPusherAppId(rssHandle);
int shuffleId = rssHandle.getShuffleId();
- String taskId = "" + context.taskAttemptId() + "_" +
context.attemptNumber();
-
ShuffleWriteMetrics writeMetrics;
if (metrics != null) {
writeMetrics = new WriteMetrics(metrics);
} else {
writeMetrics = context.taskMetrics().shuffleWriteMetrics();
}
+ ShuffleHandleInfo shuffleHandleInfo;
+ if (rssResubmitStage) {
+ // Get the ShuffleServer list from the Driver based on the shuffleId
+ shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
+ } else {
+ shuffleHandleInfo =
+ new ShuffleHandleInfo(
+ shuffleId, rssHandle.getPartitionToServers(),
rssHandle.getRemoteStorage());
+ }
+ String taskId = "" + context.taskAttemptId() + "_" +
context.attemptNumber();
LOG.info("RssHandle appId {} shuffleId {} ", rssHandle.getAppId(),
rssHandle.getShuffleId());
return new RssShuffleWriter<>(
rssHandle.getAppId(),
@@ -474,7 +504,8 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
shuffleWriteClient,
rssHandle,
this::markFailedTask,
- context);
+ context,
+ shuffleHandleInfo);
}
public void setPusherAppId(RssShuffleHandle rssShuffleHandle) {
@@ -582,8 +613,19 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
RssShuffleHandle<K, ?, C> rssShuffleHandle = (RssShuffleHandle<K, ?, C>)
handle;
final int partitionNum =
rssShuffleHandle.getDependency().partitioner().numPartitions();
int shuffleId = rssShuffleHandle.getShuffleId();
+ ShuffleHandleInfo shuffleHandleInfo;
+ if (rssResubmitStage) {
+ // Get the ShuffleServer list from the Driver based on the shuffleId
+ shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
+ } else {
+ shuffleHandleInfo =
+ new ShuffleHandleInfo(
+ shuffleId,
+ rssShuffleHandle.getPartitionToServers(),
+ rssShuffleHandle.getRemoteStorage());
+ }
Map<Integer, List<ShuffleServerInfo>> allPartitionToServers =
- rssShuffleHandle.getPartitionToServers();
+ shuffleHandleInfo.getPartitionToServers();
Map<Integer, List<ShuffleServerInfo>> requirePartitionToServers =
allPartitionToServers.entrySet().stream()
.filter(x -> x.getKey() >= startPartition && x.getKey() <
endPartition)
@@ -638,7 +680,8 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
taskIdBitmap,
readMetrics,
RssSparkConfig.toRssConf(sparkConf),
- dataDistributionType);
+ dataDistributionType,
+ allPartitionToServers);
}
@SuppressFBWarnings("REC_CATCH_EXCEPTION")
@@ -1011,4 +1054,42 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
}
return result;
}
+
  /**
   * Look up the cached {@link ShuffleHandleInfo} registered for the given shuffle.
   *
   * @param shuffleId the shuffle id
   * @return the registered handle info, or {@code null} if none was registered for this id
   */
  @Override
  public ShuffleHandleInfo getShuffleHandleInfoByShuffleId(int shuffleId) {
    return shuffleIdToShuffleHandleInfo.get(shuffleId);
  }
+
+ private ShuffleManagerClient createShuffleManagerClient(String host, int
port) {
+ // Host can be inferred from `spark.driver.bindAddress`, which would be
set when SparkContext is
+ // constructed.
+ return ShuffleManagerClientFactory.getInstance()
+ .createShuffleManagerClient(ClientType.GRPC, host, port);
+ }
+
+ /**
+ * Get the ShuffleServer list from the Driver based on the shuffleId
+ *
+ * @param shuffleId shuffleId
+ * @return ShuffleHandleInfo
+ */
+ private ShuffleHandleInfo getRemoteShuffleHandleInfo(int shuffleId) {
+ ShuffleHandleInfo shuffleHandleInfo;
+ RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
+ String driver = rssConf.getString("driver.host", "");
+ int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
+ if (shuffleManagerClient == null) {
+ shuffleManagerClient = createShuffleManagerClient(driver, port);
+ }
+ RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest =
+ new RssPartitionToShuffleServerRequest(shuffleId);
+ RssPartitionToShuffleServerResponse rpcPartitionToShufflerServer =
+
shuffleManagerClient.getPartitionToShufflerServer(rssPartitionToShuffleServerRequest);
+ shuffleHandleInfo =
+ new ShuffleHandleInfo(
+ shuffleId,
+ rpcPartitionToShufflerServer.getPartitionToServers(),
+ rpcPartitionToShufflerServer.getRemoteStorageInfo());
+ return shuffleHandleInfo;
+ }
}
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index f8ed76f73..3d7b58bbb 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -95,7 +95,8 @@ public class RssShuffleReader<K, C> implements
ShuffleReader<K, C> {
Roaring64NavigableMap taskIdBitmap,
ShuffleReadMetrics readMetrics,
RssConf rssConf,
- ShuffleDataDistributionType dataDistributionType) {
+ ShuffleDataDistributionType dataDistributionType,
+ Map<Integer, List<ShuffleServerInfo>> allPartitionToServers) {
this.appId = rssShuffleHandle.getAppId();
this.startPartition = startPartition;
this.endPartition = endPartition;
@@ -113,7 +114,7 @@ public class RssShuffleReader<K, C> implements
ShuffleReader<K, C> {
this.taskIdBitmap = taskIdBitmap;
this.hadoopConf = hadoopConf;
this.readMetrics = readMetrics;
- this.partitionToShuffleServers = rssShuffleHandle.getPartitionToServers();
+ this.partitionToShuffleServers = allPartitionToServers;
this.rssConf = rssConf;
this.dataDistributionType = dataDistributionType;
}
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 330f56c8d..1c8b89305 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -52,6 +52,7 @@ import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.shuffle.RssShuffleHandle;
import org.apache.spark.shuffle.RssShuffleManager;
import org.apache.spark.shuffle.RssSparkConfig;
+import org.apache.spark.shuffle.ShuffleHandleInfo;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.storage.BlockManagerId;
import org.slf4j.Logger;
@@ -89,6 +90,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
private final boolean isMemoryShuffleEnabled;
private final Function<String, Boolean> taskFailureCallback;
private final Set<Long> blockIds = Sets.newConcurrentHashSet();
+ private TaskContext taskContext;
/** used by columnar rss shuffle writer implementation */
protected final long taskAttemptId;
@@ -109,7 +111,9 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
RssShuffleManager shuffleManager,
SparkConf sparkConf,
ShuffleWriteClient shuffleWriteClient,
- RssShuffleHandle<K, V, C> rssHandle) {
+ RssShuffleHandle<K, V, C> rssHandle,
+ ShuffleHandleInfo shuffleHandleInfo,
+ TaskContext context) {
this(
appId,
shuffleId,
@@ -120,7 +124,9 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
sparkConf,
shuffleWriteClient,
rssHandle,
- (tid) -> true);
+ (tid) -> true,
+ shuffleHandleInfo,
+ context);
this.bufferManager = bufferManager;
}
@@ -134,7 +140,9 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
SparkConf sparkConf,
ShuffleWriteClient shuffleWriteClient,
RssShuffleHandle<K, V, C> rssHandle,
- Function<String, Boolean> taskFailureCallback) {
+ Function<String, Boolean> taskFailureCallback,
+ ShuffleHandleInfo shuffleHandleInfo,
+ TaskContext context) {
LOG.warn("RssShuffle start write taskAttemptId data" + taskAttemptId);
this.shuffleManager = shuffleManager;
this.appId = appId;
@@ -151,13 +159,14 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
this.bitmapSplitNum =
sparkConf.get(RssSparkConfig.RSS_CLIENT_BITMAP_SPLIT_NUM);
this.partitionToBlockIds = Maps.newHashMap();
this.shuffleWriteClient = shuffleWriteClient;
- this.shuffleServersForData = rssHandle.getShuffleServersForData();
+ this.shuffleServersForData = shuffleHandleInfo.getShuffleServersForData();
this.partitionLengths = new long[partitioner.numPartitions()];
Arrays.fill(partitionLengths, 0);
- partitionToServers = rssHandle.getPartitionToServers();
+ partitionToServers = shuffleHandleInfo.getPartitionToServers();
this.isMemoryShuffleEnabled =
isMemoryShuffleEnabled(sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()));
this.taskFailureCallback = taskFailureCallback;
+ this.taskContext = context;
}
public RssShuffleWriter(
@@ -171,7 +180,8 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
ShuffleWriteClient shuffleWriteClient,
RssShuffleHandle<K, V, C> rssHandle,
Function<String, Boolean> taskFailureCallback,
- TaskContext context) {
+ TaskContext context,
+ ShuffleHandleInfo shuffleHandleInfo) {
this(
appId,
shuffleId,
@@ -182,7 +192,9 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
sparkConf,
shuffleWriteClient,
rssHandle,
- taskFailureCallback);
+ taskFailureCallback,
+ shuffleHandleInfo,
+ context);
BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
final WriteBufferManager bufferManager =
new WriteBufferManager(
@@ -191,7 +203,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
taskAttemptId,
bufferOptions,
rssHandle.getDependency().serializer(),
- rssHandle.getPartitionToServers(),
+ shuffleHandleInfo.getPartitionToServers(),
context.taskMemoryManager(),
shuffleWriteMetrics,
RssSparkConfig.toRssConf(sparkConf),
diff --git
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
index a2f51065a..aaff4cb8e 100644
---
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
+++
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
@@ -109,7 +109,8 @@ public class RssShuffleReaderTest extends
AbstractRssReaderTest {
taskIdBitmap,
new ShuffleReadMetrics(),
rssConf,
- ShuffleDataDistributionType.NORMAL));
+ ShuffleDataDistributionType.NORMAL,
+ partitionToServers));
validateResult(rssShuffleReaderSpy.read(), expectedData, 10);
writeTestData(
@@ -131,7 +132,8 @@ public class RssShuffleReaderTest extends
AbstractRssReaderTest {
taskIdBitmap,
new ShuffleReadMetrics(),
rssConf,
- ShuffleDataDistributionType.NORMAL));
+ ShuffleDataDistributionType.NORMAL,
+ partitionToServers));
validateResult(rssShuffleReaderSpy1.read(), expectedData, 18);
RssShuffleReader<String, String> rssShuffleReaderSpy2 =
@@ -150,7 +152,8 @@ public class RssShuffleReaderTest extends
AbstractRssReaderTest {
Roaring64NavigableMap.bitmapOf(),
new ShuffleReadMetrics(),
rssConf,
- ShuffleDataDistributionType.NORMAL));
+ ShuffleDataDistributionType.NORMAL,
+ partitionToServers));
validateResult(rssShuffleReaderSpy2.read(), Maps.newHashMap(), 0);
}
}
diff --git
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index ca9936288..206ceb33f 100644
---
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -39,6 +39,7 @@ import com.google.common.collect.Sets;
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.memory.TaskMemoryManager;
@@ -47,6 +48,7 @@ import org.apache.spark.serializer.Serializer;
import org.apache.spark.shuffle.RssShuffleHandle;
import org.apache.spark.shuffle.RssShuffleManager;
import org.apache.spark.shuffle.RssSparkConfig;
+import org.apache.spark.shuffle.ShuffleHandleInfo;
import org.apache.spark.shuffle.TestUtils;
import org.awaitility.Awaitility;
import org.junit.jupiter.api.Test;
@@ -111,6 +113,8 @@ public class RssShuffleWriterTest {
RssSparkConfig.toRssConf(conf));
WriteBufferManager bufferManagerSpy = spy(bufferManager);
+ TaskContext contextMock = mock(TaskContext.class);
+ ShuffleHandleInfo mockShuffleHandleInfo = mock(ShuffleHandleInfo.class);
RssShuffleWriter<String, String, String> rssShuffleWriter =
new RssShuffleWriter<>(
"appId",
@@ -122,7 +126,9 @@ public class RssShuffleWriterTest {
manager,
conf,
mockShuffleWriteClient,
- mockHandle);
+ mockHandle,
+ mockShuffleHandleInfo,
+ contextMock);
doReturn(1000000L).when(bufferManagerSpy).acquireMemory(anyLong());
// case 1: all blocks are sent successfully
@@ -253,6 +259,8 @@ public class RssShuffleWriterTest {
when(mockDependency.serializer()).thenReturn(kryoSerializer);
when(mockDependency.partitioner()).thenReturn(mockPartitioner);
when(mockPartitioner.numPartitions()).thenReturn(1);
+ TaskContext contextMock = mock(TaskContext.class);
+ ShuffleHandleInfo mockShuffleHandleInfo = mock(ShuffleHandleInfo.class);
RssShuffleWriter<String, String, String> rssShuffleWriter =
new RssShuffleWriter<>(
@@ -265,7 +273,9 @@ public class RssShuffleWriterTest {
manager,
conf,
mockShuffleWriteClient,
- mockHandle);
+ mockHandle,
+ mockShuffleHandleInfo,
+ contextMock);
rssShuffleWriter.getBufferManager().setSpillFunc(rssShuffleWriter::processShuffleBlockInfos);
MutableList<Product2<String, String>> data = new MutableList<>();
@@ -377,6 +387,8 @@ public class RssShuffleWriterTest {
bufferManager.setTaskId("taskId");
WriteBufferManager bufferManagerSpy = spy(bufferManager);
+ TaskContext contextMock = mock(TaskContext.class);
+ ShuffleHandleInfo mockShuffleHandleInfo = mock(ShuffleHandleInfo.class);
RssShuffleWriter<String, String, String> rssShuffleWriter =
new RssShuffleWriter<>(
"appId",
@@ -388,7 +400,9 @@ public class RssShuffleWriterTest {
manager,
conf,
mockShuffleWriteClient,
- mockHandle);
+ mockHandle,
+ mockShuffleHandleInfo,
+ contextMock);
doReturn(1000000L).when(bufferManagerSpy).acquireMemory(anyLong());
RssShuffleWriter<String, String, String> rssShuffleWriterSpy =
spy(rssShuffleWriter);
@@ -484,6 +498,8 @@ public class RssShuffleWriterTest {
RssShuffleHandle<String, String, String> mockHandle =
mock(RssShuffleHandle.class);
when(mockHandle.getDependency()).thenReturn(mockDependency);
+ TaskContext contextMock = mock(TaskContext.class);
+ ShuffleHandleInfo mockShuffleHandleInfo = mock(ShuffleHandleInfo.class);
ShuffleWriteClient mockWriteClient = mock(ShuffleWriteClient.class);
List<ShuffleBlockInfo> shuffleBlockInfoList = createShuffleBlockList(1,
31);
@@ -498,7 +514,9 @@ public class RssShuffleWriterTest {
mockShuffleManager,
conf,
mockWriteClient,
- mockHandle);
+ mockHandle,
+ mockShuffleHandleInfo,
+ contextMock);
writer.postBlockEvent(shuffleBlockInfoList);
Awaitility.await().timeout(Duration.ofSeconds(1)).until(() ->
events.size() == 1);
assertEquals(1, events.get(0).getShuffleDataInfoList().size());
diff --git
a/common/src/main/java/org/apache/uniffle/common/ShuffleServerInfo.java
b/common/src/main/java/org/apache/uniffle/common/ShuffleServerInfo.java
index 8022bf98c..bfe99eafa 100644
--- a/common/src/main/java/org/apache/uniffle/common/ShuffleServerInfo.java
+++ b/common/src/main/java/org/apache/uniffle/common/ShuffleServerInfo.java
@@ -18,6 +18,10 @@
package org.apache.uniffle.common;
import java.io.Serializable;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import org.apache.uniffle.proto.RssProtos;
public class ShuffleServerInfo implements Serializable {
@@ -108,4 +112,37 @@ public class ShuffleServerInfo implements Serializable {
+ "]}";
}
}
+
+ private static ShuffleServerInfo convertToShuffleServerId(
+ RssProtos.ShuffleServerId shuffleServerId) {
+ ShuffleServerInfo shuffleServerInfo =
+ new ShuffleServerInfo(
+ shuffleServerId.getId(), shuffleServerId.getIp(),
shuffleServerId.getPort(), 0);
+ return shuffleServerInfo;
+ }
+
+ private static RssProtos.ShuffleServerId convertToShuffleServerId(
+ ShuffleServerInfo shuffleServerInfo) {
+ RssProtos.ShuffleServerId shuffleServerId =
+ RssProtos.ShuffleServerId.newBuilder()
+ .setId(shuffleServerInfo.getId())
+ .setIp(shuffleServerInfo.getHost())
+ .setPort(shuffleServerInfo.grpcPort)
+ .setNettyPort(shuffleServerInfo.nettyPort)
+ .build();
+ return shuffleServerId;
+ }
+
+ public static List<ShuffleServerInfo>
fromProto(List<RssProtos.ShuffleServerId> servers) {
+ return servers.stream()
+ .map(server -> convertToShuffleServerId(server))
+ .collect(Collectors.toList());
+ }
+
+ public static List<RssProtos.ShuffleServerId> toProto(
+ List<ShuffleServerInfo> shuffleServerInfos) {
+ return shuffleServerInfos.stream()
+ .map(server -> convertToShuffleServerId(server))
+ .collect(Collectors.toList());
+ }
}
diff --git
a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageResubmitTest.java
b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageResubmitTest.java
index 419efb441..5e95bc009 100644
---
a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageResubmitTest.java
+++
b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageResubmitTest.java
@@ -47,6 +47,7 @@ public class RSSStageResubmitTest extends
SparkIntegrationTestBase {
dynamicConf.put(RssSparkConfig.RSS_STORAGE_TYPE.key(),
StorageType.MEMORY_LOCALFILE.name());
dynamicConf.put(
RssSparkConfig.SPARK_RSS_CONFIG_PREFIX +
RssClientConfig.RSS_RESUBMIT_STAGE, "true");
+ dynamicConf.put(RssSparkConfig.RSS_STORAGE_TYPE.key(),
StorageType.MEMORY_LOCALFILE.name());
addDynamicConf(coordinatorConf, dynamicConf);
createCoordinatorServer(coordinatorConf);
ShuffleServerConf shuffleServerConf = getShuffleServerConf();
@@ -81,6 +82,8 @@ public class RSSStageResubmitTest extends
SparkIntegrationTestBase {
@Override
public void updateSparkConfCustomer(SparkConf sparkConf) {
+ sparkConf.set(
+ RssSparkConfig.SPARK_RSS_CONFIG_PREFIX +
RssClientConfig.RSS_RESUBMIT_STAGE, "true");
sparkConf.set("spark.task.maxFailures", String.valueOf(maxTaskFailures));
}
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
b/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
index bd3c817db..997a8e0bc 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
@@ -19,10 +19,21 @@ package org.apache.uniffle.client.api;
import java.io.Closeable;
+import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest;
import org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest;
+import org.apache.uniffle.client.response.RssPartitionToShuffleServerResponse;
import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
public interface ShuffleManagerClient extends Closeable {
RssReportShuffleFetchFailureResponse reportShuffleFetchFailure(
RssReportShuffleFetchFailureRequest request);
+
+ /**
+ * Gets the mapping between partitions and ShuffleServer from the
ShuffleManager server
+ *
+ * @param req request
+ * @return RssPartitionToShuffleServerResponse
+ */
+ RssPartitionToShuffleServerResponse getPartitionToShufflerServer(
+ RssPartitionToShuffleServerRequest req);
}
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java
index 6b03e283a..56962c885 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java
@@ -23,10 +23,13 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.uniffle.client.api.ShuffleManagerClient;
+import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest;
import org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest;
+import org.apache.uniffle.client.response.RssPartitionToShuffleServerResponse;
import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
import org.apache.uniffle.common.config.RssBaseConf;
import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.proto.RssProtos;
import org.apache.uniffle.proto.RssProtos.ReportShuffleFetchFailureRequest;
import org.apache.uniffle.proto.RssProtos.ReportShuffleFetchFailureResponse;
import org.apache.uniffle.proto.ShuffleManagerGrpc;
@@ -74,4 +77,15 @@ public class ShuffleManagerGrpcClient extends GrpcClient
implements ShuffleManag
throw new RssException(msg, e);
}
}
+
+ @Override
+ public RssPartitionToShuffleServerResponse getPartitionToShufflerServer(
+ RssPartitionToShuffleServerRequest req) {
+ RssProtos.PartitionToShuffleServerRequest protoRequest = req.toProto();
+ RssProtos.PartitionToShuffleServerResponse partitionToShufflerServer =
+ getBlockingStub().getPartitionToShufflerServer(protoRequest);
+ RssPartitionToShuffleServerResponse rssPartitionToShuffleServerResponse =
+
RssPartitionToShuffleServerResponse.fromProto(partitionToShufflerServer);
+ return rssPartitionToShuffleServerResponse;
+ }
}
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssPartitionToShuffleServerRequest.java
similarity index 58%
copy from
internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
copy to
internal-client/src/main/java/org/apache/uniffle/client/request/RssPartitionToShuffleServerRequest.java
index bd3c817db..62c3ec8f6 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssPartitionToShuffleServerRequest.java
@@ -15,14 +15,25 @@
* limitations under the License.
*/
-package org.apache.uniffle.client.api;
+package org.apache.uniffle.client.request;
-import java.io.Closeable;
+import org.apache.uniffle.proto.RssProtos;
-import org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest;
-import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
+public class RssPartitionToShuffleServerRequest {
+ private int shuffleId;
-public interface ShuffleManagerClient extends Closeable {
- RssReportShuffleFetchFailureResponse reportShuffleFetchFailure(
- RssReportShuffleFetchFailureRequest request);
+ public RssPartitionToShuffleServerRequest(int shuffleId) {
+ this.shuffleId = shuffleId;
+ }
+
+ public int getShuffleId() {
+ return shuffleId;
+ }
+
+ public RssProtos.PartitionToShuffleServerRequest toProto() {
+ RssProtos.PartitionToShuffleServerRequest.Builder builder =
+ RssProtos.PartitionToShuffleServerRequest.newBuilder();
+ builder.setShuffleId(shuffleId);
+ return builder.build();
+ }
}
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssPartitionToShuffleServerResponse.java
b/internal-client/src/main/java/org/apache/uniffle/client/response/RssPartitionToShuffleServerResponse.java
new file mode 100644
index 000000000..74d508eff
--- /dev/null
+++
b/internal-client/src/main/java/org/apache/uniffle/client/response/RssPartitionToShuffleServerResponse.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.client.response;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
+
+import org.apache.uniffle.common.RemoteStorageInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.proto.RssProtos;
+
+public class RssPartitionToShuffleServerResponse extends ClientResponse {
+
+ private Map<Integer, List<ShuffleServerInfo>> partitionToServers;
+ private Set<ShuffleServerInfo> shuffleServersForData;
+ private RemoteStorageInfo remoteStorageInfo;
+
+ public RssPartitionToShuffleServerResponse(
+ StatusCode statusCode,
+ String message,
+ Map<Integer, List<ShuffleServerInfo>> partitionToServers,
+ Set<ShuffleServerInfo> shuffleServersForData,
+ RemoteStorageInfo remoteStorageInfo) {
+ super(statusCode, message);
+ this.partitionToServers = partitionToServers;
+ this.remoteStorageInfo = remoteStorageInfo;
+ this.shuffleServersForData = shuffleServersForData;
+ }
+
+ public Map<Integer, List<ShuffleServerInfo>> getPartitionToServers() {
+ return partitionToServers;
+ }
+
+ public Set<ShuffleServerInfo> getShuffleServersForData() {
+ return shuffleServersForData;
+ }
+
+ public RemoteStorageInfo getRemoteStorageInfo() {
+ return remoteStorageInfo;
+ }
+
+ public static RssPartitionToShuffleServerResponse fromProto(
+ RssProtos.PartitionToShuffleServerResponse response) {
+ Map<Integer, RssProtos.GetShuffleServerListResponse>
partitionToShuffleServerMap =
+ response.getPartitionToShuffleServerMap();
+ Map<Integer, List<ShuffleServerInfo>> rpcPartitionToShuffleServerInfos =
Maps.newHashMap();
+ Set<Map.Entry<Integer, RssProtos.GetShuffleServerListResponse>> entries =
+ partitionToShuffleServerMap.entrySet();
+ for (Map.Entry<Integer, RssProtos.GetShuffleServerListResponse> entry :
entries) {
+ Integer partitionId = entry.getKey();
+ List<ShuffleServerInfo> shuffleServerInfos = Lists.newArrayList();
+ List<? extends RssProtos.ShuffleServerIdOrBuilder> serversOrBuilderList =
+ entry.getValue().getServersOrBuilderList();
+ for (RssProtos.ShuffleServerIdOrBuilder shuffleServerIdOrBuilder :
serversOrBuilderList) {
+ shuffleServerInfos.add(
+ new ShuffleServerInfo(
+ shuffleServerIdOrBuilder.getId(),
+ shuffleServerIdOrBuilder.getIp(),
+ shuffleServerIdOrBuilder.getPort(),
+ shuffleServerIdOrBuilder.getNettyPort()));
+ }
+
+ rpcPartitionToShuffleServerInfos.put(partitionId, shuffleServerInfos);
+ }
+ Set<ShuffleServerInfo> rpcShuffleServersForData = Sets.newHashSet();
+ for (List<ShuffleServerInfo> ssis :
rpcPartitionToShuffleServerInfos.values()) {
+ rpcShuffleServersForData.addAll(ssis);
+ }
+ RssProtos.RemoteStorageInfo protoRemoteStorageInfo =
response.getRemoteStorageInfo();
+ RemoteStorageInfo rpcRemoteStorageInfo =
+ new RemoteStorageInfo(
+ protoRemoteStorageInfo.getPath(),
protoRemoteStorageInfo.getConfItemsMap());
+ return new RssPartitionToShuffleServerResponse(
+ StatusCode.valueOf(response.getStatus().name()),
+ response.getMsg(),
+ rpcPartitionToShuffleServerInfos,
+ rpcShuffleServersForData,
+ rpcRemoteStorageInfo);
+ }
+}
diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto
index 1c8dfdd24..31e29bdae 100644
--- a/proto/src/main/proto/Rss.proto
+++ b/proto/src/main/proto/Rss.proto
@@ -497,6 +497,8 @@ message CancelDecommissionResponse {
// per application.
service ShuffleManager {
rpc reportShuffleFetchFailure (ReportShuffleFetchFailureRequest) returns
(ReportShuffleFetchFailureResponse);
+ // Gets the mapping between partitions and ShuffleServer from the
ShuffleManager server
+ rpc getPartitionToShufflerServer(PartitionToShuffleServerRequest) returns
(PartitionToShuffleServerResponse);
}
message ReportShuffleFetchFailureRequest {
@@ -516,3 +518,19 @@ message ReportShuffleFetchFailureResponse {
bool reSubmitWholeStage = 2;
string msg = 3;
}
+
message PartitionToShuffleServerRequest {
  // Fix: the sole field of this brand-new message was numbered 2 (a leftover);
  // renumbered to 1 before the message ships in any release. Safe because the
  // message is introduced in this same commit, so no wire compatibility exists yet.
  int32 shuffleId = 1;
}
+
// Maps each partition of a shuffle to the shuffle servers assigned to it,
// together with the remote storage used by the shuffle.
message PartitionToShuffleServerResponse {
  StatusCode status = 1;
  map<int32,GetShuffleServerListResponse> partitionToShuffleServer = 2;
  RemoteStorageInfo remote_storage_info = 3;
  string msg = 4;
}
+
+message RemoteStorageInfo{
+ string path = 1;
+ map<string, string> confItems = 2;
+}