This is an automated email from the ASF dual-hosted git repository.

jerrylei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new e9a62eeaf [#825][part-3] feat(spark): Get the ShuffleServer 
corresponding to the partition from ShuffleManager. (#1141)
e9a62eeaf is described below

commit e9a62eeafbced0eb38f95f35b7fb885bf0c47e99
Author: yl09099 <[email protected]>
AuthorDate: Mon Oct 30 19:08:35 2023 +0800

    [#825][part-3] feat(spark): Get the ShuffleServer corresponding to the 
partition from ShuffleManager. (#1141)
    
    ### What changes were proposed in this pull request?
    
    ShuffleReader and ShuffleWriter get the ShuffleServer corresponding to the 
partition from the ShuffleManager.
    
    ### Why are the changes needed?
    
    Fix: #825
    
    ### Does this PR introduce any user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing UT.
---
 .../manager/RssShuffleManagerInterface.java        |   9 ++
 .../shuffle/manager/ShuffleManagerGrpcService.java |  52 +++++++++++
 .../shuffle/manager/DummyRssShuffleManager.java    |   7 ++
 .../apache/spark/shuffle/RssShuffleManager.java    |  98 ++++++++++++++++++--
 .../spark/shuffle/reader/RssShuffleReader.java     |   7 +-
 .../spark/shuffle/writer/RssShuffleWriter.java     |  26 ++++--
 .../spark/shuffle/reader/RssShuffleReaderTest.java |   3 +-
 .../spark/shuffle/writer/RssShuffleWriterTest.java |  20 +++-
 .../apache/spark/shuffle/RssShuffleManager.java    | 101 +++++++++++++++++++--
 .../spark/shuffle/reader/RssShuffleReader.java     |   5 +-
 .../spark/shuffle/writer/RssShuffleWriter.java     |  28 ++++--
 .../spark/shuffle/reader/RssShuffleReaderTest.java |   9 +-
 .../spark/shuffle/writer/RssShuffleWriterTest.java |  26 +++++-
 .../apache/uniffle/common/ShuffleServerInfo.java   |  37 ++++++++
 .../apache/uniffle/test/RSSStageResubmitTest.java  |   3 +
 .../uniffle/client/api/ShuffleManagerClient.java   |  11 +++
 .../client/impl/grpc/ShuffleManagerGrpcClient.java |  14 +++
 .../RssPartitionToShuffleServerRequest.java}       |  25 +++--
 .../RssPartitionToShuffleServerResponse.java       | 101 +++++++++++++++++++++
 proto/src/main/proto/Rss.proto                     |  18 ++++
 20 files changed, 544 insertions(+), 56 deletions(-)

diff --git 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java
 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java
index 4308602dc..34009cbad 100644
--- 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java
+++ 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java
@@ -18,6 +18,7 @@
 package org.apache.uniffle.shuffle.manager;
 
 import org.apache.spark.SparkException;
+import org.apache.spark.shuffle.ShuffleHandleInfo;
 
 /**
  * This is a proxy interface that mainly delegates the un-registration of 
shuffles to the
@@ -54,4 +55,12 @@ public interface RssShuffleManagerInterface {
    * @throws SparkException
    */
   void unregisterAllMapOutput(int shuffleId) throws SparkException;
+
+  /**
+   * Get ShuffleHandleInfo with ShuffleId
+   *
+   * @param shuffleId
+   * @return ShuffleHandleInfo
+   */
+  ShuffleHandleInfo getShuffleHandleInfoByShuffleId(int shuffleId);
 }
diff --git 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
index 6dfd52e24..5c4b8795c 100644
--- 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
+++ 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
@@ -18,14 +18,18 @@
 package org.apache.uniffle.shuffle.manager;
 
 import java.util.Arrays;
+import java.util.List;
 import java.util.Map;
 import java.util.concurrent.locks.ReentrantReadWriteLock;
 import java.util.function.Supplier;
 
 import io.grpc.stub.StreamObserver;
+import org.apache.spark.shuffle.ShuffleHandleInfo;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.apache.uniffle.common.RemoteStorageInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.util.JavaUtils;
 import org.apache.uniffle.proto.RssProtos;
 import org.apache.uniffle.proto.ShuffleManagerGrpc.ShuffleManagerImplBase;
@@ -101,6 +105,54 @@ public class ShuffleManagerGrpcService extends 
ShuffleManagerImplBase {
     responseObserver.onCompleted();
   }
 
+  @Override
+  public void getPartitionToShufflerServer(
+      RssProtos.PartitionToShuffleServerRequest request,
+      StreamObserver<RssProtos.PartitionToShuffleServerResponse> 
responseObserver) {
+    RssProtos.PartitionToShuffleServerResponse reply;
+    RssProtos.StatusCode code;
+    int shuffleId = request.getShuffleId();
+    ShuffleHandleInfo shuffleHandleInfoByShuffleId =
+        shuffleManager.getShuffleHandleInfoByShuffleId(shuffleId);
+    if (shuffleHandleInfoByShuffleId != null) {
+      code = RssProtos.StatusCode.SUCCESS;
+      Map<Integer, List<ShuffleServerInfo>> partitionToServers =
+          shuffleHandleInfoByShuffleId.getPartitionToServers();
+      Map<Integer, RssProtos.GetShuffleServerListResponse> 
protopartitionToServers =
+          JavaUtils.newConcurrentMap();
+      for (Map.Entry<Integer, List<ShuffleServerInfo>> integerListEntry :
+          partitionToServers.entrySet()) {
+        List<RssProtos.ShuffleServerId> shuffleServerIds =
+            ShuffleServerInfo.toProto(integerListEntry.getValue());
+        RssProtos.GetShuffleServerListResponse getShuffleServerListResponse =
+            RssProtos.GetShuffleServerListResponse.newBuilder()
+                .addAllServers(shuffleServerIds)
+                .build();
+        protopartitionToServers.put(integerListEntry.getKey(), 
getShuffleServerListResponse);
+      }
+      RemoteStorageInfo remoteStorage = 
shuffleHandleInfoByShuffleId.getRemoteStorage();
+      RssProtos.RemoteStorageInfo.Builder protosRemoteStage =
+          RssProtos.RemoteStorageInfo.newBuilder()
+              .setPath(remoteStorage.getPath())
+              .putAllConfItems(remoteStorage.getConfItems());
+      reply =
+          RssProtos.PartitionToShuffleServerResponse.newBuilder()
+              .setStatus(code)
+              .putAllPartitionToShuffleServer(protopartitionToServers)
+              .setRemoteStorageInfo(protosRemoteStage)
+              .build();
+    } else {
+      code = RssProtos.StatusCode.INVALID_REQUEST;
+      reply =
+          RssProtos.PartitionToShuffleServerResponse.newBuilder()
+              .setStatus(code)
+              .putAllPartitionToShuffleServer(null)
+              .build();
+    }
+    responseObserver.onNext(reply);
+    responseObserver.onCompleted();
+  }
+
   /**
    * Remove the no longer used shuffle id's rss shuffle status. This is called 
when ShuffleManager
    * unregisters the corresponding shuffle id.
diff --git 
a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java
 
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java
index 9e06da4e6..dfd4b69a1 100644
--- 
a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java
+++ 
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java
@@ -20,6 +20,8 @@ package org.apache.uniffle.shuffle.manager;
 import java.util.LinkedHashSet;
 import java.util.Set;
 
+import org.apache.spark.shuffle.ShuffleHandleInfo;
+
 public class DummyRssShuffleManager implements RssShuffleManagerInterface {
   public Set<Integer> unregisteredShuffleIds = new LinkedHashSet<>();
 
@@ -47,4 +49,9 @@ public class DummyRssShuffleManager implements 
RssShuffleManagerInterface {
   public void unregisterAllMapOutput(int shuffleId) {
     unregisteredShuffleIds.add(shuffleId);
   }
+
+  @Override
+  public ShuffleHandleInfo getShuffleHandleInfoByShuffleId(int shuffleId) {
+    return null;
+  }
 }
diff --git 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 4efde2ec5..dcebe11f8 100644
--- 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -33,7 +33,6 @@ import scala.collection.Iterator;
 import scala.collection.Seq;
 
 import com.google.common.annotations.VisibleForTesting;
-import com.google.common.collect.Maps;
 import com.google.common.collect.Sets;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.spark.ShuffleDependency;
@@ -52,14 +51,21 @@ import org.roaringbitmap.longlong.Roaring64NavigableMap;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.apache.uniffle.client.api.ShuffleManagerClient;
 import org.apache.uniffle.client.api.ShuffleWriteClient;
 import org.apache.uniffle.client.factory.ShuffleClientFactory;
+import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
+import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest;
+import org.apache.uniffle.client.response.RssPartitionToShuffleServerResponse;
 import org.apache.uniffle.client.util.ClientUtils;
+import org.apache.uniffle.client.util.RssClientConfig;
+import org.apache.uniffle.common.ClientType;
 import org.apache.uniffle.common.PartitionRange;
 import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.ShuffleAssignmentsInfo;
 import org.apache.uniffle.common.ShuffleDataDistributionType;
 import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssClientConf;
 import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.exception.RssFetchFailedException;
@@ -104,10 +110,19 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
   private DataPusher dataPusher;
   private final int maxConcurrencyPerPartitionToWrite;
 
-  private final Map<Integer, Integer> shuffleIdToPartitionNum = 
Maps.newConcurrentMap();
-  private final Map<Integer, Integer> shuffleIdToNumMapTasks = 
Maps.newConcurrentMap();
+  private final Map<Integer, Integer> shuffleIdToPartitionNum = 
JavaUtils.newConcurrentMap();
+  private final Map<Integer, Integer> shuffleIdToNumMapTasks = 
JavaUtils.newConcurrentMap();
   private GrpcServer shuffleManagerServer;
   private ShuffleManagerGrpcService service;
+  private ShuffleManagerClient shuffleManagerClient;
+  /**
+   * Mapping between ShuffleId and ShuffleServer list. ShuffleServer list is 
dynamically allocated.
+   * ShuffleServer is not obtained from RssShuffleHandle, but from this 
mapping.
+   */
+  private Map<Integer, ShuffleHandleInfo> shuffleIdToShuffleHandleInfo =
+      JavaUtils.newConcurrentMap();
+  /** Whether to enable the dynamic shuffleServer function rewrite and reread 
functions */
+  private boolean rssResubmitStage;
 
   public RssShuffleManager(SparkConf sparkConf, boolean isDriver) {
     if (sparkConf.getBoolean("spark.sql.adaptive.enabled", false)) {
@@ -183,12 +198,14 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     // shuffle cluster, we don't need shuffle data locality
     sparkConf.set("spark.shuffle.reduceLocality.enabled", "false");
     LOG.info("Disable shuffle data locality in RssShuffleManager.");
+    this.rssResubmitStage =
+        rssConf.getBoolean(RssClientConfig.RSS_RESUBMIT_STAGE, false)
+            && RssSparkShuffleUtils.isStageResubmitSupported();
     if (!sparkConf.getBoolean(RssSparkConfig.RSS_TEST_FLAG.key(), false)) {
       if (isDriver) {
         heartBeatScheduledExecutorService =
             
ThreadUtils.getDaemonSingleThreadScheduledExecutor("rss-heartbeat");
-        if (sparkConf.get(RssSparkConfig.RSS_RESUBMIT_STAGE)
-            && RssSparkShuffleUtils.isStageResubmitSupported()) {
+        if (rssResubmitStage) {
           LOG.info("stage resubmit is supported and enabled");
           // start shuffle manager server
           rssConf.set(RPC_SERVER_PORT, 0);
@@ -330,6 +347,11 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
 
     shuffleIdToPartitionNum.putIfAbsent(shuffleId, 
dependency.partitioner().numPartitions());
     shuffleIdToNumMapTasks.putIfAbsent(shuffleId, 
dependency.rdd().partitions().length);
+    if (rssResubmitStage) {
+      ShuffleHandleInfo handleInfo =
+          new ShuffleHandleInfo(shuffleId, partitionToServers, remoteStorage);
+      shuffleIdToShuffleHandleInfo.put(shuffleId, handleInfo);
+    }
     Broadcast<ShuffleHandleInfo> hdlInfoBd =
         RssSparkShuffleUtils.broadcastShuffleHdlInfo(
             RssSparkShuffleUtils.getActiveSparkContext(),
@@ -421,6 +443,15 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
 
       int shuffleId = rssHandle.getShuffleId();
       String taskId = "" + context.taskAttemptId() + "_" + 
context.attemptNumber();
+      ShuffleHandleInfo shuffleHandleInfo;
+      if (rssResubmitStage) {
+        // Get the ShuffleServer list from the Driver based on the shuffleId
+        shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
+      } else {
+        shuffleHandleInfo =
+            new ShuffleHandleInfo(
+                shuffleId, rssHandle.getPartitionToServers(), 
rssHandle.getRemoteStorage());
+      }
       ShuffleWriteMetrics writeMetrics = 
context.taskMetrics().shuffleWriteMetrics();
       return new RssShuffleWriter<>(
           rssHandle.getAppId(),
@@ -433,7 +464,8 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
           shuffleWriteClient,
           rssHandle,
           this::markFailedTask,
-          context);
+          context,
+          shuffleHandleInfo);
     } else {
       throw new RssException("Unexpected ShuffleHandle:" + 
handle.getClass().getName());
     }
@@ -463,8 +495,19 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
               + startPartition
               + "]");
       start = System.currentTimeMillis();
+      ShuffleHandleInfo shuffleHandleInfo;
+      if (rssResubmitStage) {
+        // Get the ShuffleServer list from the Driver based on the shuffleId
+        shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
+      } else {
+        shuffleHandleInfo =
+            new ShuffleHandleInfo(
+                shuffleId,
+                rssShuffleHandle.getPartitionToServers(),
+                rssShuffleHandle.getRemoteStorage());
+      }
       Map<Integer, List<ShuffleServerInfo>> partitionToServers =
-          rssShuffleHandle.getPartitionToServers();
+          shuffleHandleInfo.getPartitionToServers();
       Roaring64NavigableMap blockIdBitmap =
           getShuffleResult(
               clientType,
@@ -501,7 +544,8 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
           partitionNum,
           blockIdBitmap,
           taskIdBitmap,
-          RssSparkConfig.toRssConf(sparkConf));
+          RssSparkConfig.toRssConf(sparkConf),
+          partitionToServers);
     } else {
       throw new RssException("Unexpected ShuffleHandle:" + 
handle.getClass().getName());
     }
@@ -712,4 +756,42 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     }
     return result;
   }
+
+  @Override
+  public ShuffleHandleInfo getShuffleHandleInfoByShuffleId(int shuffleId) {
+    return shuffleIdToShuffleHandleInfo.get(shuffleId);
+  }
+
+  private ShuffleManagerClient createShuffleManagerClient(String host, int 
port) {
+    // Host can be inferred from `spark.driver.bindAddress`, which would be 
set when SparkContext is
+    // constructed.
+    return ShuffleManagerClientFactory.getInstance()
+        .createShuffleManagerClient(ClientType.GRPC, host, port);
+  }
+
+  /**
+   * Get the ShuffleServer list from the Driver based on the shuffleId
+   *
+   * @param shuffleId shuffleId
+   * @return ShuffleHandleInfo
+   */
+  private ShuffleHandleInfo getRemoteShuffleHandleInfo(int shuffleId) {
+    ShuffleHandleInfo shuffleHandleInfo;
+    RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
+    String driver = rssConf.getString("driver.host", "");
+    int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
+    if (shuffleManagerClient == null) {
+      shuffleManagerClient = createShuffleManagerClient(driver, port);
+    }
+    RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest =
+        new RssPartitionToShuffleServerRequest(shuffleId);
+    RssPartitionToShuffleServerResponse rpcPartitionToShufflerServer =
+        
shuffleManagerClient.getPartitionToShufflerServer(rssPartitionToShuffleServerRequest);
+    shuffleHandleInfo =
+        new ShuffleHandleInfo(
+            shuffleId,
+            rpcPartitionToShufflerServer.getPartitionToServers(),
+            rpcPartitionToShufflerServer.getRemoteStorageInfo());
+    return shuffleHandleInfo;
+  }
 }
diff --git 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index 9130bc587..75855ba7a 100644
--- 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++ 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -18,6 +18,7 @@
 package org.apache.spark.shuffle.reader;
 
 import java.util.List;
+import java.util.Map;
 
 import scala.Function0;
 import scala.Option;
@@ -83,7 +84,8 @@ public class RssShuffleReader<K, C> implements 
ShuffleReader<K, C> {
       int partitionNum,
       Roaring64NavigableMap blockIdBitmap,
       Roaring64NavigableMap taskIdBitmap,
-      RssConf rssConf) {
+      RssConf rssConf,
+      Map<Integer, List<ShuffleServerInfo>> partitionToServers) {
     this.appId = rssShuffleHandle.getAppId();
     this.startPartition = startPartition;
     this.endPartition = endPartition;
@@ -98,8 +100,7 @@ public class RssShuffleReader<K, C> implements 
ShuffleReader<K, C> {
     this.blockIdBitmap = blockIdBitmap;
     this.taskIdBitmap = taskIdBitmap;
     this.hadoopConf = hadoopConf;
-    this.shuffleServerInfoList =
-        (List<ShuffleServerInfo>) 
(rssShuffleHandle.getPartitionToServers().get(startPartition));
+    this.shuffleServerInfoList = (List<ShuffleServerInfo>) 
(partitionToServers.get(startPartition));
     this.rssConf = rssConf;
     expectedTaskIdsBitmapFilterEnable = shuffleServerInfoList.size() > 1;
   }
diff --git 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 11f2dd3ba..fabac5651 100644
--- 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -50,6 +50,7 @@ import org.apache.spark.scheduler.MapStatus$;
 import org.apache.spark.shuffle.RssShuffleHandle;
 import org.apache.spark.shuffle.RssShuffleManager;
 import org.apache.spark.shuffle.RssSparkConfig;
+import org.apache.spark.shuffle.ShuffleHandleInfo;
 import org.apache.spark.shuffle.ShuffleWriter;
 import org.apache.spark.storage.BlockManagerId;
 import org.slf4j.Logger;
@@ -89,6 +90,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private boolean isMemoryShuffleEnabled;
   private final Function<String, Boolean> taskFailureCallback;
   private final Set<Long> blockIds = Sets.newConcurrentHashSet();
+  private TaskContext taskContext;
 
   public RssShuffleWriter(
       String appId,
@@ -100,7 +102,9 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       RssShuffleManager shuffleManager,
       SparkConf sparkConf,
       ShuffleWriteClient shuffleWriteClient,
-      RssShuffleHandle<K, V, C> rssHandle) {
+      RssShuffleHandle<K, V, C> rssHandle,
+      ShuffleHandleInfo shuffleHandleInfo,
+      TaskContext context) {
     this(
         appId,
         shuffleId,
@@ -111,7 +115,9 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         sparkConf,
         shuffleWriteClient,
         rssHandle,
-        (tid) -> true);
+        (tid) -> true,
+        shuffleHandleInfo,
+        context);
     this.bufferManager = bufferManager;
   }
 
@@ -125,7 +131,9 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       SparkConf sparkConf,
       ShuffleWriteClient shuffleWriteClient,
       RssShuffleHandle<K, V, C> rssHandle,
-      Function<String, Boolean> taskFailureCallback) {
+      Function<String, Boolean> taskFailureCallback,
+      ShuffleHandleInfo shuffleHandleInfo,
+      TaskContext context) {
     this.appId = appId;
     this.shuffleId = shuffleId;
     this.taskId = taskId;
@@ -141,11 +149,12 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     this.bitmapSplitNum = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_BITMAP_SPLIT_NUM);
     this.partitionToBlockIds = Maps.newHashMap();
     this.shuffleWriteClient = shuffleWriteClient;
-    this.shuffleServersForData = rssHandle.getShuffleServersForData();
-    this.partitionToServers = rssHandle.getPartitionToServers();
+    this.shuffleServersForData = shuffleHandleInfo.getShuffleServersForData();
+    this.partitionToServers = shuffleHandleInfo.getPartitionToServers();
     this.isMemoryShuffleEnabled =
         
isMemoryShuffleEnabled(sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()));
     this.taskFailureCallback = taskFailureCallback;
+    this.taskContext = context;
   }
 
   public RssShuffleWriter(
@@ -159,7 +168,8 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       ShuffleWriteClient shuffleWriteClient,
       RssShuffleHandle<K, V, C> rssHandle,
       Function<String, Boolean> taskFailureCallback,
-      TaskContext context) {
+      TaskContext context,
+      ShuffleHandleInfo shuffleHandleInfo) {
     this(
         appId,
         shuffleId,
@@ -170,7 +180,9 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         sparkConf,
         shuffleWriteClient,
         rssHandle,
-        taskFailureCallback);
+        taskFailureCallback,
+        shuffleHandleInfo,
+        context);
     BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
     final WriteBufferManager bufferManager =
         new WriteBufferManager(
diff --git 
a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
 
b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
index a6c25516a..f09223b1c 100644
--- 
a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
+++ 
b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
@@ -98,7 +98,8 @@ public class RssShuffleReaderTest extends 
AbstractRssReaderTest {
                 10,
                 blockIdBitmap,
                 taskIdBitmap,
-                rssConf));
+                rssConf,
+                partitionToServers));
 
     validateResult(rssShuffleReaderSpy.read(), expectedData, 10);
   }
diff --git 
a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
 
b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index e60930d8c..13fb93f7a 100644
--- 
a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++ 
b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -37,6 +37,7 @@ import org.apache.spark.Partitioner;
 import org.apache.spark.ShuffleDependency;
 import org.apache.spark.SparkConf;
 import org.apache.spark.SparkContext;
+import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.executor.TaskMetrics;
 import org.apache.spark.memory.TaskMemoryManager;
@@ -45,6 +46,7 @@ import org.apache.spark.serializer.Serializer;
 import org.apache.spark.shuffle.RssShuffleHandle;
 import org.apache.spark.shuffle.RssShuffleManager;
 import org.apache.spark.shuffle.RssSparkConfig;
+import org.apache.spark.shuffle.ShuffleHandleInfo;
 import org.junit.jupiter.api.Test;
 
 import org.apache.uniffle.client.api.ShuffleWriteClient;
@@ -93,6 +95,8 @@ public class RssShuffleWriterTest {
     when(mockPartitioner.numPartitions()).thenReturn(2);
     when(mockHandle.getPartitionToServers()).thenReturn(Maps.newHashMap());
     TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class);
+    TaskContext contextMock = mock(TaskContext.class);
+    ShuffleHandleInfo mockShuffleHandleInfo = mock(ShuffleHandleInfo.class);
 
     BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
     WriteBufferManager bufferManager =
@@ -119,7 +123,9 @@ public class RssShuffleWriterTest {
             manager,
             conf,
             mockShuffleWriteClient,
-            mockHandle);
+            mockHandle,
+            mockShuffleHandleInfo,
+            contextMock);
 
     // case 1: all blocks are sent successfully
     manager.addSuccessBlockIds(taskId, Sets.newHashSet(1L, 2L, 3L));
@@ -274,6 +280,8 @@ public class RssShuffleWriterTest {
             null);
     WriteBufferManager bufferManagerSpy = spy(bufferManager);
     doReturn(1000000L).when(bufferManagerSpy).acquireMemory(anyLong());
+    TaskContext contextMock = mock(TaskContext.class);
+    ShuffleHandleInfo mockShuffleHandleInfo = mock(ShuffleHandleInfo.class);
 
     RssShuffleWriter<String, String, String> rssShuffleWriter =
         new RssShuffleWriter<>(
@@ -286,7 +294,9 @@ public class RssShuffleWriterTest {
             manager,
             conf,
             mockShuffleWriteClient,
-            mockHandle);
+            mockHandle,
+            mockShuffleHandleInfo,
+            contextMock);
 
     RssShuffleWriter<String, String, String> rssShuffleWriterSpy = 
spy(rssShuffleWriter);
     doNothing().when(rssShuffleWriterSpy).sendCommit();
@@ -382,6 +392,8 @@ public class RssShuffleWriterTest {
     RssShuffleHandle<String, String, String> mockHandle = 
mock(RssShuffleHandle.class);
     when(mockHandle.getDependency()).thenReturn(mockDependency);
     ShuffleWriteClient mockWriteClient = mock(ShuffleWriteClient.class);
+    TaskContext contextMock = mock(TaskContext.class);
+    ShuffleHandleInfo mockShuffleHandleInfo = mock(ShuffleHandleInfo.class);
 
     RssShuffleWriter<String, String, String> writer =
         new RssShuffleWriter<>(
@@ -394,7 +406,9 @@ public class RssShuffleWriterTest {
             manager,
             conf,
             mockWriteClient,
-            mockHandle);
+            mockHandle,
+            mockShuffleHandleInfo,
+            contextMock);
     List<ShuffleBlockInfo> shuffleBlockInfoList = createShuffleBlockList(1, 
31);
     writer.postBlockEvent(shuffleBlockInfoList);
     Thread.sleep(500);
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 9ec5e90ae..d9e730e7e 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -36,7 +36,6 @@ import scala.collection.Iterator;
 import scala.collection.Seq;
 
 import com.google.common.annotations.VisibleForTesting;
-import com.google.common.collect.Maps;
 import com.google.common.collect.Sets;
 import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
 import org.apache.hadoop.conf.Configuration;
@@ -59,9 +58,15 @@ import org.roaringbitmap.longlong.Roaring64NavigableMap;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.apache.uniffle.client.api.ShuffleManagerClient;
 import org.apache.uniffle.client.api.ShuffleWriteClient;
 import org.apache.uniffle.client.factory.ShuffleClientFactory;
+import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
+import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest;
+import org.apache.uniffle.client.response.RssPartitionToShuffleServerResponse;
 import org.apache.uniffle.client.util.ClientUtils;
+import org.apache.uniffle.client.util.RssClientConfig;
+import org.apache.uniffle.common.ClientType;
 import org.apache.uniffle.common.PartitionRange;
 import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.ShuffleAssignmentsInfo;
@@ -111,8 +116,8 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
   private Set<String> failedTaskIds = Sets.newConcurrentHashSet();
   private DataPusher dataPusher;
 
-  private final Map<Integer, Integer> shuffleIdToPartitionNum = 
Maps.newConcurrentMap();
-  private final Map<Integer, Integer> shuffleIdToNumMapTasks = 
Maps.newConcurrentMap();
+  private final Map<Integer, Integer> shuffleIdToPartitionNum = 
JavaUtils.newConcurrentMap();
+  private final Map<Integer, Integer> shuffleIdToNumMapTasks = 
JavaUtils.newConcurrentMap();
   private ShuffleManagerGrpcService service;
   private GrpcServer shuffleManagerServer;
 
@@ -121,6 +126,15 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
 
   protected ShuffleWriteClient shuffleWriteClient;
 
+  private ShuffleManagerClient shuffleManagerClient;
+  /**
+   * Mapping between ShuffleId and ShuffleServer list. ShuffleServer list is 
dynamically allocated.
+   * ShuffleServer is not obtained from RssShuffleHandle, but from this 
mapping.
+   */
+  private Map<Integer, ShuffleHandleInfo> shuffleIdToShuffleHandleInfo;
+  /** Whether to enable the dynamic shuffleServer function rewrite and reread 
functions */
+  private boolean rssResubmitStage;
+
   public RssShuffleManager(SparkConf conf, boolean isDriver) {
     this.sparkConf = conf;
     boolean supportsRelocation =
@@ -209,11 +223,13 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     taskToSuccessBlockIds = JavaUtils.newConcurrentMap();
     taskToFailedBlockIds = JavaUtils.newConcurrentMap();
     this.taskToFailedBlockIdsAndServer = JavaUtils.newConcurrentMap();
+    this.rssResubmitStage =
+        rssConf.getBoolean(RssClientConfig.RSS_RESUBMIT_STAGE, false)
+            && RssSparkShuffleUtils.isStageResubmitSupported();
     if (isDriver) {
       heartBeatScheduledExecutorService =
           ThreadUtils.getDaemonSingleThreadScheduledExecutor("rss-heartbeat");
-      if (sparkConf.get(RssSparkConfig.RSS_RESUBMIT_STAGE)
-          && RssSparkShuffleUtils.isStageResubmitSupported()) {
+      if (rssResubmitStage) {
         LOG.info("stage resubmit is supported and enabled");
         // start shuffle manager server
         rssConf.set(RPC_SERVER_PORT, 0);
@@ -244,6 +260,7 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
             failedTaskIds,
             poolSize,
             keepAliveTime);
+    this.shuffleIdToShuffleHandleInfo = JavaUtils.newConcurrentMap();
   }
 
   public CompletableFuture<Long> sendData(AddBlockEvent event) {
@@ -428,6 +445,11 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
 
     shuffleIdToPartitionNum.putIfAbsent(shuffleId, 
dependency.partitioner().numPartitions());
     shuffleIdToNumMapTasks.putIfAbsent(shuffleId, 
dependency.rdd().partitions().length);
+    if (rssResubmitStage) {
+      ShuffleHandleInfo handleInfo =
+          new ShuffleHandleInfo(shuffleId, partitionToServers, remoteStorage);
+      shuffleIdToShuffleHandleInfo.put(shuffleId, handleInfo);
+    }
     Broadcast<ShuffleHandleInfo> hdlInfoBd =
         RssSparkShuffleUtils.broadcastShuffleHdlInfo(
             RssSparkShuffleUtils.getActiveSparkContext(),
@@ -454,14 +476,22 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     RssShuffleHandle<K, V, ?> rssHandle = (RssShuffleHandle<K, V, ?>) handle;
     setPusherAppId(rssHandle);
     int shuffleId = rssHandle.getShuffleId();
-    String taskId = "" + context.taskAttemptId() + "_" + 
context.attemptNumber();
-
     ShuffleWriteMetrics writeMetrics;
     if (metrics != null) {
       writeMetrics = new WriteMetrics(metrics);
     } else {
       writeMetrics = context.taskMetrics().shuffleWriteMetrics();
     }
+    ShuffleHandleInfo shuffleHandleInfo;
+    if (rssResubmitStage) {
+      // Get the ShuffleServer list from the Driver based on the shuffleId
+      shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
+    } else {
+      shuffleHandleInfo =
+          new ShuffleHandleInfo(
+              shuffleId, rssHandle.getPartitionToServers(), 
rssHandle.getRemoteStorage());
+    }
+    String taskId = "" + context.taskAttemptId() + "_" + 
context.attemptNumber();
     LOG.info("RssHandle appId {} shuffleId {} ", rssHandle.getAppId(), 
rssHandle.getShuffleId());
     return new RssShuffleWriter<>(
         rssHandle.getAppId(),
@@ -474,7 +504,8 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
         shuffleWriteClient,
         rssHandle,
         this::markFailedTask,
-        context);
+        context,
+        shuffleHandleInfo);
   }
 
   public void setPusherAppId(RssShuffleHandle rssShuffleHandle) {
@@ -582,8 +613,19 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     RssShuffleHandle<K, ?, C> rssShuffleHandle = (RssShuffleHandle<K, ?, C>) 
handle;
     final int partitionNum = 
rssShuffleHandle.getDependency().partitioner().numPartitions();
     int shuffleId = rssShuffleHandle.getShuffleId();
+    ShuffleHandleInfo shuffleHandleInfo;
+    if (rssResubmitStage) {
+      // Get the ShuffleServer list from the Driver based on the shuffleId
+      shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
+    } else {
+      shuffleHandleInfo =
+          new ShuffleHandleInfo(
+              shuffleId,
+              rssShuffleHandle.getPartitionToServers(),
+              rssShuffleHandle.getRemoteStorage());
+    }
     Map<Integer, List<ShuffleServerInfo>> allPartitionToServers =
-        rssShuffleHandle.getPartitionToServers();
+        shuffleHandleInfo.getPartitionToServers();
     Map<Integer, List<ShuffleServerInfo>> requirePartitionToServers =
         allPartitionToServers.entrySet().stream()
             .filter(x -> x.getKey() >= startPartition && x.getKey() < 
endPartition)
@@ -638,7 +680,8 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
         taskIdBitmap,
         readMetrics,
         RssSparkConfig.toRssConf(sparkConf),
-        dataDistributionType);
+        dataDistributionType,
+        allPartitionToServers);
   }
 
   @SuppressFBWarnings("REC_CATCH_EXCEPTION")
@@ -1011,4 +1054,42 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     }
     return result;
   }
+
+  @Override
+  public ShuffleHandleInfo getShuffleHandleInfoByShuffleId(int shuffleId) {
+    return shuffleIdToShuffleHandleInfo.get(shuffleId);
+  }
+
+  private ShuffleManagerClient createShuffleManagerClient(String host, int 
port) {
+    // Host can be inferred from `spark.driver.bindAddress`, which would be 
set when SparkContext is
+    // constructed.
+    return ShuffleManagerClientFactory.getInstance()
+        .createShuffleManagerClient(ClientType.GRPC, host, port);
+  }
+
+  /**
+   * Gets the ShuffleServer list from the Driver based on the given shuffleId.
+   *
+   * @param shuffleId the id of the shuffle whose server assignment is requested
+   * @return the {@link ShuffleHandleInfo} holding the partition-to-server mapping and remote
+   *     storage info for that shuffle
+   */
+  private ShuffleHandleInfo getRemoteShuffleHandleInfo(int shuffleId) {
+    ShuffleHandleInfo shuffleHandleInfo;
+    RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
+    String driver = rssConf.getString("driver.host", "");
+    int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
+    if (shuffleManagerClient == null) {
+      shuffleManagerClient = createShuffleManagerClient(driver, port);
+    }
+    RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest =
+        new RssPartitionToShuffleServerRequest(shuffleId);
+    RssPartitionToShuffleServerResponse rpcPartitionToShufflerServer =
+        
shuffleManagerClient.getPartitionToShufflerServer(rssPartitionToShuffleServerRequest);
+    shuffleHandleInfo =
+        new ShuffleHandleInfo(
+            shuffleId,
+            rpcPartitionToShufflerServer.getPartitionToServers(),
+            rpcPartitionToShufflerServer.getRemoteStorageInfo());
+    return shuffleHandleInfo;
+  }
 }
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index f8ed76f73..3d7b58bbb 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -95,7 +95,8 @@ public class RssShuffleReader<K, C> implements 
ShuffleReader<K, C> {
       Roaring64NavigableMap taskIdBitmap,
       ShuffleReadMetrics readMetrics,
       RssConf rssConf,
-      ShuffleDataDistributionType dataDistributionType) {
+      ShuffleDataDistributionType dataDistributionType,
+      Map<Integer, List<ShuffleServerInfo>> allPartitionToServers) {
     this.appId = rssShuffleHandle.getAppId();
     this.startPartition = startPartition;
     this.endPartition = endPartition;
@@ -113,7 +114,7 @@ public class RssShuffleReader<K, C> implements 
ShuffleReader<K, C> {
     this.taskIdBitmap = taskIdBitmap;
     this.hadoopConf = hadoopConf;
     this.readMetrics = readMetrics;
-    this.partitionToShuffleServers = rssShuffleHandle.getPartitionToServers();
+    this.partitionToShuffleServers = allPartitionToServers;
     this.rssConf = rssConf;
     this.dataDistributionType = dataDistributionType;
   }
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 330f56c8d..1c8b89305 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -52,6 +52,7 @@ import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.shuffle.RssShuffleHandle;
 import org.apache.spark.shuffle.RssShuffleManager;
 import org.apache.spark.shuffle.RssSparkConfig;
+import org.apache.spark.shuffle.ShuffleHandleInfo;
 import org.apache.spark.shuffle.ShuffleWriter;
 import org.apache.spark.storage.BlockManagerId;
 import org.slf4j.Logger;
@@ -89,6 +90,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private final boolean isMemoryShuffleEnabled;
   private final Function<String, Boolean> taskFailureCallback;
   private final Set<Long> blockIds = Sets.newConcurrentHashSet();
+  private TaskContext taskContext;
 
   /** used by columnar rss shuffle writer implementation */
   protected final long taskAttemptId;
@@ -109,7 +111,9 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       RssShuffleManager shuffleManager,
       SparkConf sparkConf,
       ShuffleWriteClient shuffleWriteClient,
-      RssShuffleHandle<K, V, C> rssHandle) {
+      RssShuffleHandle<K, V, C> rssHandle,
+      ShuffleHandleInfo shuffleHandleInfo,
+      TaskContext context) {
     this(
         appId,
         shuffleId,
@@ -120,7 +124,9 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         sparkConf,
         shuffleWriteClient,
         rssHandle,
-        (tid) -> true);
+        (tid) -> true,
+        shuffleHandleInfo,
+        context);
     this.bufferManager = bufferManager;
   }
 
@@ -134,7 +140,9 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       SparkConf sparkConf,
       ShuffleWriteClient shuffleWriteClient,
       RssShuffleHandle<K, V, C> rssHandle,
-      Function<String, Boolean> taskFailureCallback) {
+      Function<String, Boolean> taskFailureCallback,
+      ShuffleHandleInfo shuffleHandleInfo,
+      TaskContext context) {
     LOG.warn("RssShuffle start write taskAttemptId data" + taskAttemptId);
     this.shuffleManager = shuffleManager;
     this.appId = appId;
@@ -151,13 +159,14 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     this.bitmapSplitNum = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_BITMAP_SPLIT_NUM);
     this.partitionToBlockIds = Maps.newHashMap();
     this.shuffleWriteClient = shuffleWriteClient;
-    this.shuffleServersForData = rssHandle.getShuffleServersForData();
+    this.shuffleServersForData = shuffleHandleInfo.getShuffleServersForData();
     this.partitionLengths = new long[partitioner.numPartitions()];
     Arrays.fill(partitionLengths, 0);
-    partitionToServers = rssHandle.getPartitionToServers();
+    partitionToServers = shuffleHandleInfo.getPartitionToServers();
     this.isMemoryShuffleEnabled =
         
isMemoryShuffleEnabled(sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()));
     this.taskFailureCallback = taskFailureCallback;
+    this.taskContext = context;
   }
 
   public RssShuffleWriter(
@@ -171,7 +180,8 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       ShuffleWriteClient shuffleWriteClient,
       RssShuffleHandle<K, V, C> rssHandle,
       Function<String, Boolean> taskFailureCallback,
-      TaskContext context) {
+      TaskContext context,
+      ShuffleHandleInfo shuffleHandleInfo) {
     this(
         appId,
         shuffleId,
@@ -182,7 +192,9 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         sparkConf,
         shuffleWriteClient,
         rssHandle,
-        taskFailureCallback);
+        taskFailureCallback,
+        shuffleHandleInfo,
+        context);
     BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
     final WriteBufferManager bufferManager =
         new WriteBufferManager(
@@ -191,7 +203,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
             taskAttemptId,
             bufferOptions,
             rssHandle.getDependency().serializer(),
-            rssHandle.getPartitionToServers(),
+            shuffleHandleInfo.getPartitionToServers(),
             context.taskMemoryManager(),
             shuffleWriteMetrics,
             RssSparkConfig.toRssConf(sparkConf),
diff --git 
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
 
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
index a2f51065a..aaff4cb8e 100644
--- 
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
+++ 
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
@@ -109,7 +109,8 @@ public class RssShuffleReaderTest extends 
AbstractRssReaderTest {
                 taskIdBitmap,
                 new ShuffleReadMetrics(),
                 rssConf,
-                ShuffleDataDistributionType.NORMAL));
+                ShuffleDataDistributionType.NORMAL,
+                partitionToServers));
     validateResult(rssShuffleReaderSpy.read(), expectedData, 10);
 
     writeTestData(
@@ -131,7 +132,8 @@ public class RssShuffleReaderTest extends 
AbstractRssReaderTest {
                 taskIdBitmap,
                 new ShuffleReadMetrics(),
                 rssConf,
-                ShuffleDataDistributionType.NORMAL));
+                ShuffleDataDistributionType.NORMAL,
+                partitionToServers));
     validateResult(rssShuffleReaderSpy1.read(), expectedData, 18);
 
     RssShuffleReader<String, String> rssShuffleReaderSpy2 =
@@ -150,7 +152,8 @@ public class RssShuffleReaderTest extends 
AbstractRssReaderTest {
                 Roaring64NavigableMap.bitmapOf(),
                 new ShuffleReadMetrics(),
                 rssConf,
-                ShuffleDataDistributionType.NORMAL));
+                ShuffleDataDistributionType.NORMAL,
+                partitionToServers));
     validateResult(rssShuffleReaderSpy2.read(), Maps.newHashMap(), 0);
   }
 }
diff --git 
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
 
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index ca9936288..206ceb33f 100644
--- 
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++ 
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -39,6 +39,7 @@ import com.google.common.collect.Sets;
 import org.apache.spark.Partitioner;
 import org.apache.spark.ShuffleDependency;
 import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.executor.TaskMetrics;
 import org.apache.spark.memory.TaskMemoryManager;
@@ -47,6 +48,7 @@ import org.apache.spark.serializer.Serializer;
 import org.apache.spark.shuffle.RssShuffleHandle;
 import org.apache.spark.shuffle.RssShuffleManager;
 import org.apache.spark.shuffle.RssSparkConfig;
+import org.apache.spark.shuffle.ShuffleHandleInfo;
 import org.apache.spark.shuffle.TestUtils;
 import org.awaitility.Awaitility;
 import org.junit.jupiter.api.Test;
@@ -111,6 +113,8 @@ public class RssShuffleWriterTest {
             RssSparkConfig.toRssConf(conf));
     WriteBufferManager bufferManagerSpy = spy(bufferManager);
 
+    TaskContext contextMock = mock(TaskContext.class);
+    ShuffleHandleInfo mockShuffleHandleInfo = mock(ShuffleHandleInfo.class);
     RssShuffleWriter<String, String, String> rssShuffleWriter =
         new RssShuffleWriter<>(
             "appId",
@@ -122,7 +126,9 @@ public class RssShuffleWriterTest {
             manager,
             conf,
             mockShuffleWriteClient,
-            mockHandle);
+            mockHandle,
+            mockShuffleHandleInfo,
+            contextMock);
     doReturn(1000000L).when(bufferManagerSpy).acquireMemory(anyLong());
 
     // case 1: all blocks are sent successfully
@@ -253,6 +259,8 @@ public class RssShuffleWriterTest {
     when(mockDependency.serializer()).thenReturn(kryoSerializer);
     when(mockDependency.partitioner()).thenReturn(mockPartitioner);
     when(mockPartitioner.numPartitions()).thenReturn(1);
+    TaskContext contextMock = mock(TaskContext.class);
+    ShuffleHandleInfo mockShuffleHandleInfo = mock(ShuffleHandleInfo.class);
 
     RssShuffleWriter<String, String, String> rssShuffleWriter =
         new RssShuffleWriter<>(
@@ -265,7 +273,9 @@ public class RssShuffleWriterTest {
             manager,
             conf,
             mockShuffleWriteClient,
-            mockHandle);
+            mockHandle,
+            mockShuffleHandleInfo,
+            contextMock);
     
rssShuffleWriter.getBufferManager().setSpillFunc(rssShuffleWriter::processShuffleBlockInfos);
 
     MutableList<Product2<String, String>> data = new MutableList<>();
@@ -377,6 +387,8 @@ public class RssShuffleWriterTest {
     bufferManager.setTaskId("taskId");
 
     WriteBufferManager bufferManagerSpy = spy(bufferManager);
+    TaskContext contextMock = mock(TaskContext.class);
+    ShuffleHandleInfo mockShuffleHandleInfo = mock(ShuffleHandleInfo.class);
     RssShuffleWriter<String, String, String> rssShuffleWriter =
         new RssShuffleWriter<>(
             "appId",
@@ -388,7 +400,9 @@ public class RssShuffleWriterTest {
             manager,
             conf,
             mockShuffleWriteClient,
-            mockHandle);
+            mockHandle,
+            mockShuffleHandleInfo,
+            contextMock);
     doReturn(1000000L).when(bufferManagerSpy).acquireMemory(anyLong());
 
     RssShuffleWriter<String, String, String> rssShuffleWriterSpy = 
spy(rssShuffleWriter);
@@ -484,6 +498,8 @@ public class RssShuffleWriterTest {
 
     RssShuffleHandle<String, String, String> mockHandle = 
mock(RssShuffleHandle.class);
     when(mockHandle.getDependency()).thenReturn(mockDependency);
+    TaskContext contextMock = mock(TaskContext.class);
+    ShuffleHandleInfo mockShuffleHandleInfo = mock(ShuffleHandleInfo.class);
     ShuffleWriteClient mockWriteClient = mock(ShuffleWriteClient.class);
 
     List<ShuffleBlockInfo> shuffleBlockInfoList = createShuffleBlockList(1, 
31);
@@ -498,7 +514,9 @@ public class RssShuffleWriterTest {
             mockShuffleManager,
             conf,
             mockWriteClient,
-            mockHandle);
+            mockHandle,
+            mockShuffleHandleInfo,
+            contextMock);
     writer.postBlockEvent(shuffleBlockInfoList);
     Awaitility.await().timeout(Duration.ofSeconds(1)).until(() -> 
events.size() == 1);
     assertEquals(1, events.get(0).getShuffleDataInfoList().size());
diff --git 
a/common/src/main/java/org/apache/uniffle/common/ShuffleServerInfo.java 
b/common/src/main/java/org/apache/uniffle/common/ShuffleServerInfo.java
index 8022bf98c..bfe99eafa 100644
--- a/common/src/main/java/org/apache/uniffle/common/ShuffleServerInfo.java
+++ b/common/src/main/java/org/apache/uniffle/common/ShuffleServerInfo.java
@@ -18,6 +18,10 @@
 package org.apache.uniffle.common;
 
 import java.io.Serializable;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import org.apache.uniffle.proto.RssProtos;
 
 public class ShuffleServerInfo implements Serializable {
 
@@ -108,4 +112,37 @@ public class ShuffleServerInfo implements Serializable {
           + "]}";
     }
   }
+
+  private static ShuffleServerInfo convertToShuffleServerId(
+      RssProtos.ShuffleServerId shuffleServerId) {
+    ShuffleServerInfo shuffleServerInfo =
+        new ShuffleServerInfo(
+            shuffleServerId.getId(), shuffleServerId.getIp(), 
shuffleServerId.getPort(), 0);
+    return shuffleServerInfo;
+  }
+
+  private static RssProtos.ShuffleServerId convertToShuffleServerId(
+      ShuffleServerInfo shuffleServerInfo) {
+    RssProtos.ShuffleServerId shuffleServerId =
+        RssProtos.ShuffleServerId.newBuilder()
+            .setId(shuffleServerInfo.getId())
+            .setIp(shuffleServerInfo.getHost())
+            .setPort(shuffleServerInfo.grpcPort)
+            .setNettyPort(shuffleServerInfo.nettyPort)
+            .build();
+    return shuffleServerId;
+  }
+
+  public static List<ShuffleServerInfo> 
fromProto(List<RssProtos.ShuffleServerId> servers) {
+    return servers.stream()
+        .map(server -> convertToShuffleServerId(server))
+        .collect(Collectors.toList());
+  }
+
+  public static List<RssProtos.ShuffleServerId> toProto(
+      List<ShuffleServerInfo> shuffleServerInfos) {
+    return shuffleServerInfos.stream()
+        .map(server -> convertToShuffleServerId(server))
+        .collect(Collectors.toList());
+  }
 }
diff --git 
a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageResubmitTest.java
 
b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageResubmitTest.java
index 419efb441..5e95bc009 100644
--- 
a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageResubmitTest.java
+++ 
b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageResubmitTest.java
@@ -47,6 +47,7 @@ public class RSSStageResubmitTest extends 
SparkIntegrationTestBase {
     dynamicConf.put(RssSparkConfig.RSS_STORAGE_TYPE.key(), 
StorageType.MEMORY_LOCALFILE.name());
     dynamicConf.put(
         RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + 
RssClientConfig.RSS_RESUBMIT_STAGE, "true");
+    dynamicConf.put(RssSparkConfig.RSS_STORAGE_TYPE.key(), 
StorageType.MEMORY_LOCALFILE.name());
     addDynamicConf(coordinatorConf, dynamicConf);
     createCoordinatorServer(coordinatorConf);
     ShuffleServerConf shuffleServerConf = getShuffleServerConf();
@@ -81,6 +82,8 @@ public class RSSStageResubmitTest extends 
SparkIntegrationTestBase {
 
   @Override
   public void updateSparkConfCustomer(SparkConf sparkConf) {
+    sparkConf.set(
+        RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + 
RssClientConfig.RSS_RESUBMIT_STAGE, "true");
     sparkConf.set("spark.task.maxFailures", String.valueOf(maxTaskFailures));
   }
 
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
index bd3c817db..997a8e0bc 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
@@ -19,10 +19,21 @@ package org.apache.uniffle.client.api;
 
 import java.io.Closeable;
 
+import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest;
 import org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest;
+import org.apache.uniffle.client.response.RssPartitionToShuffleServerResponse;
 import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
 
 public interface ShuffleManagerClient extends Closeable {
   RssReportShuffleFetchFailureResponse reportShuffleFetchFailure(
       RssReportShuffleFetchFailureRequest request);
+
+  /**
+   * Gets the mapping between partitions and ShuffleServers from the ShuffleManager server.
+   *
+   * @param req the request carrying the shuffleId to look up
+   * @return a {@link RssPartitionToShuffleServerResponse} with the partition-to-server mapping
+   */
+  RssPartitionToShuffleServerResponse getPartitionToShufflerServer(
+      RssPartitionToShuffleServerRequest req);
 }
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java
index 6b03e283a..56962c885 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java
@@ -23,10 +23,13 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.uniffle.client.api.ShuffleManagerClient;
+import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest;
 import org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest;
+import org.apache.uniffle.client.response.RssPartitionToShuffleServerResponse;
 import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
 import org.apache.uniffle.common.config.RssBaseConf;
 import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.proto.RssProtos;
 import org.apache.uniffle.proto.RssProtos.ReportShuffleFetchFailureRequest;
 import org.apache.uniffle.proto.RssProtos.ReportShuffleFetchFailureResponse;
 import org.apache.uniffle.proto.ShuffleManagerGrpc;
@@ -74,4 +77,15 @@ public class ShuffleManagerGrpcClient extends GrpcClient 
implements ShuffleManag
       throw new RssException(msg, e);
     }
   }
+
+  @Override
+  public RssPartitionToShuffleServerResponse getPartitionToShufflerServer(
+      RssPartitionToShuffleServerRequest req) {
+    RssProtos.PartitionToShuffleServerRequest protoRequest = req.toProto();
+    RssProtos.PartitionToShuffleServerResponse partitionToShufflerServer =
+        getBlockingStub().getPartitionToShufflerServer(protoRequest);
+    RssPartitionToShuffleServerResponse rssPartitionToShuffleServerResponse =
+        
RssPartitionToShuffleServerResponse.fromProto(partitionToShufflerServer);
+    return rssPartitionToShuffleServerResponse;
+  }
 }
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssPartitionToShuffleServerRequest.java
similarity index 58%
copy from 
internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
copy to 
internal-client/src/main/java/org/apache/uniffle/client/request/RssPartitionToShuffleServerRequest.java
index bd3c817db..62c3ec8f6 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssPartitionToShuffleServerRequest.java
@@ -15,14 +15,25 @@
  * limitations under the License.
  */
 
-package org.apache.uniffle.client.api;
+package org.apache.uniffle.client.request;
 
-import java.io.Closeable;
+import org.apache.uniffle.proto.RssProtos;
 
-import org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest;
-import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
+public class RssPartitionToShuffleServerRequest {
+  private int shuffleId;
 
-public interface ShuffleManagerClient extends Closeable {
-  RssReportShuffleFetchFailureResponse reportShuffleFetchFailure(
-      RssReportShuffleFetchFailureRequest request);
+  public RssPartitionToShuffleServerRequest(int shuffleId) {
+    this.shuffleId = shuffleId;
+  }
+
+  public int getShuffleId() {
+    return shuffleId;
+  }
+
+  public RssProtos.PartitionToShuffleServerRequest toProto() {
+    RssProtos.PartitionToShuffleServerRequest.Builder builder =
+        RssProtos.PartitionToShuffleServerRequest.newBuilder();
+    builder.setShuffleId(shuffleId);
+    return builder.build();
+  }
 }
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssPartitionToShuffleServerResponse.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/response/RssPartitionToShuffleServerResponse.java
new file mode 100644
index 000000000..74d508eff
--- /dev/null
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/response/RssPartitionToShuffleServerResponse.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.client.response;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
+
+import org.apache.uniffle.common.RemoteStorageInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.proto.RssProtos;
+
+public class RssPartitionToShuffleServerResponse extends ClientResponse {
+
+  private Map<Integer, List<ShuffleServerInfo>> partitionToServers;
+  private Set<ShuffleServerInfo> shuffleServersForData;
+  private RemoteStorageInfo remoteStorageInfo;
+
+  public RssPartitionToShuffleServerResponse(
+      StatusCode statusCode,
+      String message,
+      Map<Integer, List<ShuffleServerInfo>> partitionToServers,
+      Set<ShuffleServerInfo> shuffleServersForData,
+      RemoteStorageInfo remoteStorageInfo) {
+    super(statusCode, message);
+    this.partitionToServers = partitionToServers;
+    this.remoteStorageInfo = remoteStorageInfo;
+    this.shuffleServersForData = shuffleServersForData;
+  }
+
+  public Map<Integer, List<ShuffleServerInfo>> getPartitionToServers() {
+    return partitionToServers;
+  }
+
+  public Set<ShuffleServerInfo> getShuffleServersForData() {
+    return shuffleServersForData;
+  }
+
+  public RemoteStorageInfo getRemoteStorageInfo() {
+    return remoteStorageInfo;
+  }
+
+  public static RssPartitionToShuffleServerResponse fromProto(
+      RssProtos.PartitionToShuffleServerResponse response) {
+    Map<Integer, RssProtos.GetShuffleServerListResponse> 
partitionToShuffleServerMap =
+        response.getPartitionToShuffleServerMap();
+    Map<Integer, List<ShuffleServerInfo>> rpcPartitionToShuffleServerInfos = 
Maps.newHashMap();
+    Set<Map.Entry<Integer, RssProtos.GetShuffleServerListResponse>> entries =
+        partitionToShuffleServerMap.entrySet();
+    for (Map.Entry<Integer, RssProtos.GetShuffleServerListResponse> entry : 
entries) {
+      Integer partitionId = entry.getKey();
+      List<ShuffleServerInfo> shuffleServerInfos = Lists.newArrayList();
+      List<? extends RssProtos.ShuffleServerIdOrBuilder> serversOrBuilderList =
+          entry.getValue().getServersOrBuilderList();
+      for (RssProtos.ShuffleServerIdOrBuilder shuffleServerIdOrBuilder : 
serversOrBuilderList) {
+        shuffleServerInfos.add(
+            new ShuffleServerInfo(
+                shuffleServerIdOrBuilder.getId(),
+                shuffleServerIdOrBuilder.getIp(),
+                shuffleServerIdOrBuilder.getPort(),
+                shuffleServerIdOrBuilder.getNettyPort()));
+      }
+
+      rpcPartitionToShuffleServerInfos.put(partitionId, shuffleServerInfos);
+    }
+    Set<ShuffleServerInfo> rpcShuffleServersForData = Sets.newHashSet();
+    for (List<ShuffleServerInfo> ssis : 
rpcPartitionToShuffleServerInfos.values()) {
+      rpcShuffleServersForData.addAll(ssis);
+    }
+    RssProtos.RemoteStorageInfo protoRemoteStorageInfo = 
response.getRemoteStorageInfo();
+    RemoteStorageInfo rpcRemoteStorageInfo =
+        new RemoteStorageInfo(
+            protoRemoteStorageInfo.getPath(), 
protoRemoteStorageInfo.getConfItemsMap());
+    return new RssPartitionToShuffleServerResponse(
+        StatusCode.valueOf(response.getStatus().name()),
+        response.getMsg(),
+        rpcPartitionToShuffleServerInfos,
+        rpcShuffleServersForData,
+        rpcRemoteStorageInfo);
+  }
+}
diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto
index 1c8dfdd24..31e29bdae 100644
--- a/proto/src/main/proto/Rss.proto
+++ b/proto/src/main/proto/Rss.proto
@@ -497,6 +497,8 @@ message CancelDecommissionResponse {
 // per application.
 service ShuffleManager {
   rpc reportShuffleFetchFailure (ReportShuffleFetchFailureRequest) returns 
(ReportShuffleFetchFailureResponse);
+  // Gets the mapping between partitions and ShuffleServers from the 
+ ShuffleManager server.
+  rpc getPartitionToShufflerServer(PartitionToShuffleServerRequest) returns 
(PartitionToShuffleServerResponse);
 }
 
 message ReportShuffleFetchFailureRequest {
@@ -516,3 +518,19 @@ message ReportShuffleFetchFailureResponse {
   bool reSubmitWholeStage = 2;
   string msg = 3;
 }
+
+message PartitionToShuffleServerRequest {
+  int32 shuffleId = 2;
+}
+
+message PartitionToShuffleServerResponse {
+  StatusCode status = 1;
+  map<int32,GetShuffleServerListResponse> partitionToShuffleServer = 2;
+  RemoteStorageInfo remote_storage_info = 3;
+  string msg = 4;
+}
+
+message RemoteStorageInfo{
+  string path = 1;
+  map<string, string> confItems = 2;
+}

Reply via email to