This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 0facb7be1 [#2568] feat(spark): Use space-efficient protobuf for 
`MutableShuffleHandleInfo` to reduce RPC memory overhead (#2578)
0facb7be1 is described below

commit 0facb7be1913c088eab5c5d9970ab0766a884a99
Author: Junfan Zhang <zus...@apache.org>
AuthorDate: Tue Aug 19 14:24:06 2025 +0800

    [#2568] feat(spark): Use space-efficient protobuf for 
`MutableShuffleHandleInfo` to reduce RPC memory overhead (#2578)
    
    ### What changes were proposed in this pull request?
    
    This PR uses a space-efficient protobuf data structure to store the 
partitions-to-servers mapping, thereby reducing the RPC cost.
    
    ### Why are the changes needed?
    
    This PR is part of the work for #2568.
    
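    In large-scale Spark jobs, the number of partitions can reach up to 20K, 
whereas the number of assigned shuffle servers remains smaller than the total 
number of nodes in the Uniffle cluster. Grouping the mapping by server instead 
of by partition therefore lets each server's identity be serialized once, 
rather than once for every partition replica it is assigned to.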
    Prior to this PR, both the driver and the client (when reassignment was 
enabled) required substantial memory for RPC transfers, which could 
significantly increase the frequency of driver garbage collection.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Unit tests.
---
 .../shuffle/handle/MutableShuffleHandleInfo.java   | 108 +++++++++++++++------
 .../handle/MutableShuffleHandleInfoTest.java       |  72 ++++++++++++++
 .../apache/uniffle/common/ShuffleServerInfo.java   |   2 +-
 proto/src/main/proto/Rss.proto                     |  15 ++-
 4 files changed, 164 insertions(+), 33 deletions(-)

diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java
index 8e0fe9d35..0e22c017c 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java
@@ -19,6 +19,7 @@ package org.apache.spark.shuffle.handle;
 
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
@@ -30,6 +31,8 @@ import java.util.stream.Collectors;
 import com.google.common.annotations.VisibleForTesting;
 import org.apache.commons.collections4.CollectionUtils;
 import org.apache.commons.lang3.StringUtils;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.commons.lang3.tuple.Triple;
 import org.apache.spark.TaskContext;
 import org.apache.spark.shuffle.handle.split.PartitionSplitInfo;
 import org.slf4j.Logger;
@@ -208,10 +211,15 @@ public class MutableShuffleHandleInfo extends 
ShuffleHandleInfoBase {
       PartitionSplitInfo splitInfo = this.getPartitionSplitInfo(partitionId);
       for (Map.Entry<Integer, List<ShuffleServerInfo>> replicaServerEntry :
           replicaServers.entrySet()) {
-        ShuffleServerInfo candidate;
-        int candidateSize = replicaServerEntry.getValue().size();
-        // Use the last one for each replica writing
-        candidate = replicaServerEntry.getValue().get(candidateSize - 1);
+
+        // For normal partition reassignment, the latest replacement shuffle 
server is always used.
+        Optional<ShuffleServerInfo> candidateOptional =
+            replicaServerEntry.getValue().stream()
+                .filter(x -> 
!excludedServerToReplacements.containsKey(x.getId()))
+                .findFirst();
+        // Get the unexcluded server for each replica writing
+        ShuffleServerInfo candidate =
+            
replicaServerEntry.getValue().get(replicaServerEntry.getValue().size() - 1);
 
         long taskAttemptId =
             Optional.ofNullable(TaskContext.get()).map(x -> 
x.taskAttemptId()).orElse(-1L);
@@ -301,27 +309,42 @@ public class MutableShuffleHandleInfo extends 
ShuffleHandleInfoBase {
 
   public static RssProtos.MutableShuffleHandleInfo 
toProto(MutableShuffleHandleInfo handleInfo) {
     synchronized (handleInfo) {
-      Map<Integer, RssProtos.PartitionReplicaServers> partitionToServers = new 
HashMap<>();
+      // value: (PartitionId, ReplicaIndex, SequenceIndex)
+      Map<ShuffleServerInfo, List<Triple<Integer, Integer, Integer>>> 
serverToPartitions =
+          new HashMap<>();
       for (Map.Entry<Integer, Map<Integer, List<ShuffleServerInfo>>> entry :
           handleInfo.partitionReplicaAssignedServers.entrySet()) {
         int partitionId = entry.getKey();
-
-        Map<Integer, RssProtos.ReplicaServersItem> replicaServersProto = new 
HashMap<>();
         Map<Integer, List<ShuffleServerInfo>> replicaServers = 
entry.getValue();
         for (Map.Entry<Integer, List<ShuffleServerInfo>> replicaServerEntry :
             replicaServers.entrySet()) {
-          RssProtos.ReplicaServersItem item =
-              RssProtos.ReplicaServersItem.newBuilder()
-                  
.addAllServerId(ShuffleServerInfo.toProto(replicaServerEntry.getValue()))
-                  .build();
-          replicaServersProto.put(replicaServerEntry.getKey(), item);
+          int replicaIndex = replicaServerEntry.getKey();
+          List<ShuffleServerInfo> servers = replicaServerEntry.getValue();
+          for (int i = 0; i < servers.size(); i++) {
+            ShuffleServerInfo server = servers.get(i);
+            serverToPartitions
+                .computeIfAbsent(server, x -> new ArrayList<>())
+                .add(Triple.of(partitionId, replicaIndex, i));
+          }
         }
-
-        RssProtos.PartitionReplicaServers partitionReplicaServerProto =
-            RssProtos.PartitionReplicaServers.newBuilder()
-                .putAllReplicaServers(replicaServersProto)
-                .build();
-        partitionToServers.put(partitionId, partitionReplicaServerProto);
+      }
+      List<RssProtos.ServerToPartitionsItem> protoServerToPartitionsItems = 
new ArrayList<>();
+      for (Map.Entry<ShuffleServerInfo, List<Triple<Integer, Integer, 
Integer>>> entry :
+          serverToPartitions.entrySet()) {
+        List<RssProtos.PartitionReplicaItem> replicaItems = new ArrayList<>();
+        for (Triple<Integer, Integer, Integer> partitionReplica : 
entry.getValue()) {
+          replicaItems.add(
+              RssProtos.PartitionReplicaItem.newBuilder()
+                  .setPartitionId(partitionReplica.getLeft())
+                  .setReplicaIndex(partitionReplica.getMiddle())
+                  .setSequenceIndex(partitionReplica.getRight())
+                  .build());
+        }
+        protoServerToPartitionsItems.add(
+            RssProtos.ServerToPartitionsItem.newBuilder()
+                
.setServerId(ShuffleServerInfo.convertToShuffleServerId(entry.getKey()))
+                .addAllPartitionToReplicaItems(replicaItems)
+                .build());
       }
 
       Map<String, RssProtos.ReplacementServers> excludeToReplacements = new 
HashMap<>();
@@ -347,8 +370,8 @@ public class MutableShuffleHandleInfo extends 
ShuffleHandleInfoBase {
                       .setPath(handleInfo.remoteStorage.getPath())
                       .putAllConfItems(handleInfo.remoteStorage.getConfItems())
                       .build())
-              .putAllPartitionToServers(partitionToServers)
               .putAllExcludedServerToReplacements(excludeToReplacements)
+              .addAllServerToPartitionItem(protoServerToPartitionsItems)
               .setPartitionSplitMode(mode)
               
.addAllSplitPartitionId(handleInfo.excludedServerForPartitionToReplacements.keySet())
               .build();
@@ -360,18 +383,41 @@ public class MutableShuffleHandleInfo extends 
ShuffleHandleInfoBase {
     if (handleProto == null) {
       return null;
     }
-    Map<Integer, Map<Integer, List<ShuffleServerInfo>>> partitionToServers = 
new HashMap<>();
-    for (Map.Entry<Integer, RssProtos.PartitionReplicaServers> entry :
-        handleProto.getPartitionToServersMap().entrySet()) {
-      Map<Integer, List<ShuffleServerInfo>> replicaServers =
-          partitionToServers.computeIfAbsent(entry.getKey(), x -> new 
HashMap<>());
-      for (Map.Entry<Integer, RssProtos.ReplicaServersItem> serverEntry :
-          entry.getValue().getReplicaServersMap().entrySet()) {
-        int replicaIdx = serverEntry.getKey();
-        List<ShuffleServerInfo> shuffleServerInfos =
-            
ShuffleServerInfo.fromProto(serverEntry.getValue().getServerIdList());
-        replicaServers.put(replicaIdx, shuffleServerInfos);
+    Map<Integer, Map<Integer, List<Pair<Integer, ShuffleServerInfo>>>> 
partitionToServers =
+        new HashMap<>();
+    for (RssProtos.ServerToPartitionsItem item : 
handleProto.getServerToPartitionItemList()) {
+      ShuffleServerInfo shuffleServerInfo =
+          ShuffleServerInfo.convertFromShuffleServerId(item.getServerId());
+      for (RssProtos.PartitionReplicaItem partitionReplicaItem :
+          item.getPartitionToReplicaItemsList()) {
+        int partitionId = partitionReplicaItem.getPartitionId();
+        int replicaIndex = partitionReplicaItem.getReplicaIndex();
+        int sequenceIndex = partitionReplicaItem.getSequenceIndex();
+        partitionToServers
+            .computeIfAbsent(partitionId, k -> new HashMap<>())
+            .computeIfAbsent(replicaIndex, k -> new ArrayList<>())
+            .add(Pair.of(sequenceIndex, shuffleServerInfo));
+      }
+    }
+    Map<Integer, Map<Integer, List<ShuffleServerInfo>>> 
partitionReplicaAssignedServers =
+        new HashMap<>();
+    for (Map.Entry<Integer, Map<Integer, List<Pair<Integer, 
ShuffleServerInfo>>>> partitionEntry :
+        partitionToServers.entrySet()) {
+      int partitionId = partitionEntry.getKey();
+      Map<Integer, List<Pair<Integer, ShuffleServerInfo>>> replicaMap = 
partitionEntry.getValue();
+      Map<Integer, List<ShuffleServerInfo>> replicaAssignedMap = new 
HashMap<>();
+      for (Map.Entry<Integer, List<Pair<Integer, ShuffleServerInfo>>> 
replicaEntry :
+          replicaMap.entrySet()) {
+        int replicaIndex = replicaEntry.getKey();
+        List<Pair<Integer, ShuffleServerInfo>> pairs = replicaEntry.getValue();
+
+        pairs.sort(Comparator.comparingInt(Pair::getLeft));
+
+        List<ShuffleServerInfo> servers =
+            pairs.stream().map(Pair::getRight).collect(Collectors.toList());
+        replicaAssignedMap.put(replicaIndex, servers);
       }
+      partitionReplicaAssignedServers.put(partitionId, replicaAssignedMap);
     }
 
     Map<String, Set<ShuffleServerInfo>> excludeToReplacements = new 
HashMap<>();
@@ -393,7 +439,7 @@ public class MutableShuffleHandleInfo extends 
ShuffleHandleInfoBase {
             handleProto.getRemoteStorageInfo().getConfItemsMap());
     MutableShuffleHandleInfo handle =
         new MutableShuffleHandleInfo(handleProto.getShuffleId(), 
remoteStorageInfo);
-    handle.partitionReplicaAssignedServers = partitionToServers;
+    handle.partitionReplicaAssignedServers = partitionReplicaAssignedServers;
     handle.partitionSplitMode =
         handleProto.getPartitionSplitMode() == 
RssProtos.PartitionSplitMode.LOAD_BALANCE
             ? PartitionSplitMode.LOAD_BALANCE
diff --git 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfoTest.java
 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfoTest.java
index 3a24f301b..cf75152e7 100644
--- 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfoTest.java
+++ 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfoTest.java
@@ -25,12 +25,14 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 
+import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Sets;
 import org.junit.jupiter.api.Test;
 
 import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking;
 import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.proto.RssProtos;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -41,6 +43,76 @@ public class MutableShuffleHandleInfoTest {
     return new ShuffleServerInfo(id, id, 1);
   }
 
+  private static boolean mapsIsEqual(
+      Map<Integer, List<ShuffleServerInfo>> map1, Map<Integer, 
List<ShuffleServerInfo>> map2) {
+
+    if (map1 == map2) {
+      return true;
+    }
+    if (map1 == null || map2 == null) {
+      return false;
+    }
+    if (map1.size() != map2.size()) {
+      return false;
+    }
+    for (Map.Entry<Integer, List<ShuffleServerInfo>> entry : map1.entrySet()) {
+      List<ShuffleServerInfo> list1 = entry.getValue();
+      List<ShuffleServerInfo> list2 = map2.get(entry.getKey());
+      if (list2 == null || list1.size() != list2.size()) {
+        return false;
+      }
+      // all the elements should be equal with the same index
+      for (int i = 0; i < list1.size(); i++) {
+        if (!list2.get(i).equals(list1.get(i))) {
+          return false;
+        }
+      }
+    }
+    return true;
+  }
+
+  @Test
+  public void testSerializationWithProtobuf() {
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
+    partitionToServers.put(1, Arrays.asList(createFakeServerInfo("a"), 
createFakeServerInfo("b")));
+    partitionToServers.put(2, Arrays.asList(createFakeServerInfo("c")));
+
+    // case1: with single replica
+    MutableShuffleHandleInfo handleInfo =
+        new MutableShuffleHandleInfo(1, partitionToServers, new 
RemoteStorageInfo(""));
+
+    RssProtos.MutableShuffleHandleInfo serialized = 
MutableShuffleHandleInfo.toProto(handleInfo);
+    MutableShuffleHandleInfo deserialized = 
MutableShuffleHandleInfo.fromProto(serialized);
+
+    assert (mapsIsEqual(deserialized.getAllPartitionServersForReader(), 
partitionToServers));
+
+    // case2: with multi replicas
+    Map<Integer, Map<Integer, List<ShuffleServerInfo>>> 
partitionWithReplicaServers =
+        new HashMap<>();
+    partitionWithReplicaServers.put(
+        1,
+        ImmutableMap.of(
+            0, Arrays.asList(createFakeServerInfo("a"), 
createFakeServerInfo("b")),
+            1, Arrays.asList(createFakeServerInfo("c"), 
createFakeServerInfo("d"))));
+    partitionWithReplicaServers.put(
+        2, ImmutableMap.of(0, Arrays.asList(createFakeServerInfo("c"))));
+    MutableShuffleHandleInfo replicaHandleInfo =
+        new MutableShuffleHandleInfo(2, new RemoteStorageInfo(""), 
partitionWithReplicaServers);
+    serialized = MutableShuffleHandleInfo.toProto(replicaHandleInfo);
+    deserialized = MutableShuffleHandleInfo.fromProto(serialized);
+
+    assert (mapsIsEqual(
+        deserialized.getAllPartitionServersForReader(),
+        ImmutableMap.of(
+            1,
+                Arrays.asList(
+                    createFakeServerInfo("a"),
+                    createFakeServerInfo("b"),
+                    createFakeServerInfo("c"),
+                    createFakeServerInfo("d")),
+            2, Arrays.asList(createFakeServerInfo("c")))));
+  }
+
   @Test
   public void testUpdateAssignment() {
     Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
diff --git 
a/common/src/main/java/org/apache/uniffle/common/ShuffleServerInfo.java 
b/common/src/main/java/org/apache/uniffle/common/ShuffleServerInfo.java
index 6ebea551a..300a7369c 100644
--- a/common/src/main/java/org/apache/uniffle/common/ShuffleServerInfo.java
+++ b/common/src/main/java/org/apache/uniffle/common/ShuffleServerInfo.java
@@ -114,7 +114,7 @@ public class ShuffleServerInfo implements Serializable {
     }
   }
 
-  private static ShuffleServerInfo convertFromShuffleServerId(
+  public static ShuffleServerInfo convertFromShuffleServerId(
       RssProtos.ShuffleServerId shuffleServerId) {
     ShuffleServerInfo shuffleServerInfo =
         new ShuffleServerInfo(
diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto
index be1589952..1ad6f2a5d 100644
--- a/proto/src/main/proto/Rss.proto
+++ b/proto/src/main/proto/Rss.proto
@@ -685,12 +685,25 @@ message StageAttemptShuffleHandleInfo {
 
 message MutableShuffleHandleInfo {
   int32 shuffleId = 1;
-  map<int32, PartitionReplicaServers> partitionToServers = 2;
+  repeated ServerToPartitionsItem serverToPartitionItem = 2;
   RemoteStorageInfo remoteStorageInfo = 3;
 
   map<string, ReplacementServers> excludedServerToReplacements = 4;
   repeated int32 splitPartitionId = 5;
   PartitionSplitMode partitionSplitMode = 6;
+
+}
+
+message ServerToPartitionsItem {
+  ShuffleServerId serverId = 1;
+  repeated PartitionReplicaItem partitionToReplicaItems = 2;
+}
+
+message PartitionReplicaItem {
+  int32 partitionId = 1;
+  int32 replicaIndex = 2;
+  // the index of assigned servers list
+  int32 sequenceIndex = 3;
 }
 
 enum PartitionSplitMode {

Reply via email to