This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 4805d1335 [#2636] feat(spark): Cache shuffle handle info for reader to 
reduce RPC cost when partition reassign is enabled (#2637)
4805d1335 is described below

commit 4805d1335aae06fa699ee3c2f17d94f599f3cb2c
Author: Junfan Zhang <[email protected]>
AuthorDate: Tue Sep 30 14:22:27 2025 +0800

    [#2636] feat(spark): Cache shuffle handle info for reader to reduce RPC 
cost when partition reassign is enabled (#2637)
    
    ### What changes were proposed in this pull request?
    
    This PR introduces a cache mechanism for the read-side shuffle 
handle info, to reduce the RPC cost and the driver's GC pressure when the 
partition reassign is enabled
    
    ### Why are the changes needed?
    
    for #2636 .
    
    From the cluster Spark jobs, I found that some tasks failed due to 
failures of the RPC that fetches the shuffle handle from the driver side when 
the partition reassign is enabled. This is the first step toward optimizing 
shuffle handle retrieval on the reader side.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes.
    
    `rss.client.read.shuffleHandleCacheEnabled=false`
    
    ### How was this patch tested?
    
    Existing tests
---
 .../org/apache/spark/shuffle/RssSparkConfig.java   |  6 +++
 .../shuffle/manager/RssShuffleManagerBase.java     | 42 +++++++++++++++++++
 .../apache/spark/shuffle/RssShuffleManager.java    | 27 +++++++-----
 .../spark/shuffle/RssShuffleManagerTest.java       | 48 ++++++++++++++++++++++
 .../test/PartitionBlockDataReassignBasicTest.java  |  2 +
 5 files changed, 115 insertions(+), 10 deletions(-)

diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index 4af6cd421..8dc17b70f 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -39,6 +39,12 @@ import org.apache.uniffle.common.config.RssConf;
 
 public class RssSparkConfig {
 
+  public static final ConfigOption<Boolean> 
RSS_READ_SHUFFLE_HANDLE_CACHE_ENABLED =
+      ConfigOptions.key("rss.client.read.shuffleHandleCacheEnabled")
+          .booleanType()
+          .defaultValue(false)
+          .withDescription("Whether or not to read shuffle handle cache 
enabled");
+
   public static final ConfigOption<Boolean> 
RSS_READ_OVERLAPPING_DECOMPRESSION_ENABLED =
       ConfigOptions.key("rss.client.read.overlappingDecompressionEnable")
           .booleanType()
diff --git 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
index ad1c4dd9a..be5e56a05 100644
--- 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
+++ 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
@@ -109,6 +109,7 @@ import org.apache.uniffle.shuffle.ShuffleIdMappingManager;
 import static org.apache.spark.launcher.SparkLauncher.EXECUTOR_CORES;
 import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED;
 import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM;
+import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_READ_SHUFFLE_HANDLE_CACHE_ENABLED;
 import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED;
 import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED;
 import static org.apache.spark.shuffle.RssSparkShuffleUtils.isSparkUIEnabled;
@@ -184,6 +185,11 @@ public abstract class RssShuffleManagerBase implements 
RssShuffleManagerInterfac
   private AtomicBoolean reassignTriggeredOnBlockSendFailure = new 
AtomicBoolean(false);
   private AtomicBoolean reassignTriggeredOnStageRetry = new 
AtomicBoolean(false);
 
+  // cache to shuffle handle info to reduce the RPC cost when getting the 
reader.
+  // this is only valid when the partition reassign is enabled.
+  protected final boolean readShuffleHandleCacheEnabled;
+  private Map<Integer, ShuffleHandleInfo> readShuffleHandleCache = 
Maps.newConcurrentMap();
+
   private boolean isDriver = false;
 
   public RssShuffleManagerBase(SparkConf conf, boolean isDriver) {
@@ -373,6 +379,8 @@ public abstract class RssShuffleManagerBase implements 
RssShuffleManagerInterfac
     this.shuffleHandleInfoManager = new ShuffleHandleInfoManager();
     this.rssStageResubmitManager = new RssStageResubmitManager();
     this.shuffleIdMappingManager = new ShuffleIdMappingManager();
+
+    this.readShuffleHandleCacheEnabled = 
rssConf.get(RSS_READ_SHUFFLE_HANDLE_CACHE_ENABLED);
   }
 
   @VisibleForTesting
@@ -424,6 +432,7 @@ public abstract class RssShuffleManagerBase implements 
RssShuffleManagerInterfac
     this.shuffleHandleInfoManager = new ShuffleHandleInfoManager();
     this.rssStageResubmitManager = new RssStageResubmitManager();
     this.shuffleIdMappingManager = new ShuffleIdMappingManager();
+    this.readShuffleHandleCacheEnabled = 
rssConf.get(RSS_READ_SHUFFLE_HANDLE_CACHE_ENABLED);
   }
 
   public BlockIdManager getBlockIdManager() {
@@ -444,6 +453,9 @@ public abstract class RssShuffleManagerBase implements 
RssShuffleManagerInterfac
       if (blockIdManager != null) {
         blockIdManager.remove(shuffleId);
       }
+      if (readShuffleHandleCache != null) {
+        readShuffleHandleCache.remove(shuffleId);
+      }
       if (SparkEnv.get().executorId().equals("driver")) {
         shuffleWriteClient.unregisterShuffle(getAppId(), shuffleId);
         shuffleIdToPartitionNum.remove(shuffleId);
@@ -1574,4 +1586,34 @@ public abstract class RssShuffleManagerBase implements 
RssShuffleManagerInterfac
     }
     return new CompletableFuture<>();
   }
+
+  // only for tests
+  public void clearShuffleHandleCache() {
+    readShuffleHandleCache.clear();
+  }
+
+  public ShuffleHandleInfo getOrFetchShuffleHandle(
+      int shuffleId, Supplier<ShuffleHandleInfo> func) {
+    ShuffleHandleInfo handle =
+        readShuffleHandleCache.computeIfAbsent(
+            shuffleId,
+            integer -> {
+              long start = System.currentTimeMillis();
+              try {
+                return func.get();
+              } catch (Exception e) {
+                LOG.error("Fail to get the shuffle handle for {}", shuffleId, 
e);
+              } finally {
+                LOG.info(
+                    "Gotten the shuffle handle for shuffle: {} that costs {} 
ms",
+                    shuffleId,
+                    System.currentTimeMillis() - start);
+              }
+              return null;
+            });
+    if (handle == null) {
+      throw new RssException("Shuffle handle id " + shuffleId + " not found");
+    }
+    return handle;
+  }
 }
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 58349b56c..e4b180af1 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -23,6 +23,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Supplier;
 import java.util.stream.Collectors;
 
 import scala.Tuple2;
@@ -365,16 +366,22 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     final int partitionNum = 
rssShuffleHandle.getDependency().partitioner().numPartitions();
     int shuffleId = rssShuffleHandle.getShuffleId();
     ShuffleHandleInfo shuffleHandleInfo;
-    if (shuffleManagerRpcServiceEnabled && 
rssStageRetryForWriteFailureEnabled) {
-      // In Stage Retry mode, Get the ShuffleServer list from the Driver based 
on the shuffleId.
-      shuffleHandleInfo =
-          getRemoteShuffleHandleInfoWithStageRetry(
-              context.stageId(), context.stageAttemptNumber(), shuffleId, 
false);
-    } else if (shuffleManagerRpcServiceEnabled && partitionReassignEnabled) {
-      // In Block Retry mode, Get the ShuffleServer list from the Driver based 
on the shuffleId.
-      shuffleHandleInfo =
-          getRemoteShuffleHandleInfoWithBlockRetry(
-              context.stageId(), context.stageAttemptNumber(), shuffleId, 
false);
+
+    if (shuffleManagerRpcServiceEnabled
+        && (rssStageRetryForWriteFailureEnabled || partitionReassignEnabled)) {
+      Supplier<ShuffleHandleInfo> func =
+          rssStageRetryForWriteFailureEnabled
+              ? () ->
+                  getRemoteShuffleHandleInfoWithStageRetry(
+                      context.stageId(), context.stageAttemptNumber(), 
shuffleId, false)
+              : () ->
+                  getRemoteShuffleHandleInfoWithBlockRetry(
+                      context.stageId(), context.stageAttemptNumber(), 
shuffleId, false);
+      if (readShuffleHandleCacheEnabled) {
+        shuffleHandleInfo = super.getOrFetchShuffleHandle(shuffleId, func);
+      } else {
+        shuffleHandleInfo = func.get();
+      }
     } else {
       shuffleHandleInfo =
           new SimpleShuffleHandleInfo(
diff --git 
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
 
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
index 877bf6e62..df6255c5f 100644
--- 
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
+++ 
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
@@ -17,14 +17,22 @@
 
 package org.apache.spark.shuffle;
 
+import java.util.Collections;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Supplier;
+
 import org.apache.spark.Partitioner;
 import org.apache.spark.ShuffleDependency;
 import org.apache.spark.SparkConf;
+import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
+import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo;
 import org.apache.spark.sql.internal.SQLConf;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.ValueSource;
 
+import org.apache.uniffle.client.impl.FailedBlockSendTracker;
 import org.apache.uniffle.client.util.RssClientConfig;
 import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.ShuffleDataDistributionType;
@@ -33,6 +41,7 @@ import org.apache.uniffle.common.config.RssClientConf;
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.rpc.StatusCode;
 import org.apache.uniffle.common.util.BlockIdLayout;
+import org.apache.uniffle.common.util.JavaUtils;
 import org.apache.uniffle.storage.util.StorageType;
 
 import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED;
@@ -42,6 +51,7 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertNull;
 import static org.junit.jupiter.api.Assertions.assertThrowsExactly;
 import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assertions.fail;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
@@ -252,4 +262,42 @@ public class RssShuffleManagerTest extends 
RssShuffleManagerTestBase {
     conf.set("spark.driver.host", "localhost");
     return conf;
   }
+
+  @Test
+  public void testReadCacheShuffleInfo() {
+    SparkConf conf = new SparkConf();
+    conf.setAppName("testApp")
+        .setMaster("local[2]")
+        .set(RssSparkConfig.RSS_TEST_FLAG.key(), "true")
+        .set(RssSparkConfig.RSS_TEST_MODE_ENABLE.key(), "true")
+        .set(RssSparkConfig.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS.key(), "10000")
+        .set(RssSparkConfig.RSS_CLIENT_RETRY_MAX.key(), "10")
+        .set(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS.key(), "1000")
+        .set(RssSparkConfig.RSS_STORAGE_TYPE.key(), 
StorageType.LOCALFILE.name())
+        .set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), 
"127.0.0.1:12345,127.0.0.1:12346");
+    Map<String, Set<Long>> successBlocks = JavaUtils.newConcurrentMap();
+    Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker = 
JavaUtils.newConcurrentMap();
+    RssShuffleManager manager =
+        TestUtils.createShuffleManager(
+            conf, false, null, successBlocks, taskToFailedBlockSendTracker);
+
+    // case1: legal fetch and cache
+    Supplier<ShuffleHandleInfo> func1 =
+        () ->
+            new SimpleShuffleHandleInfo(
+                1, Collections.emptyMap(), 
RemoteStorageInfo.EMPTY_REMOTE_STORAGE);
+    ShuffleHandleInfo handle1 = manager.getOrFetchShuffleHandle(1, func1);
+    ShuffleHandleInfo handle2 = manager.getOrFetchShuffleHandle(1, func1);
+    assertEquals(handle1, handle2);
+
+    // case2: illegal fetch
+    manager.clearShuffleHandleCache();
+    Supplier<ShuffleHandleInfo> func2 = () -> null;
+    try {
+      ShuffleHandleInfo handle3 = manager.getOrFetchShuffleHandle(1, func2);
+      fail();
+    } catch (Exception e) {
+      // ignore
+    }
+  }
 }
diff --git 
a/integration-test/spark3/src/test/java/org/apache/uniffle/test/PartitionBlockDataReassignBasicTest.java
 
b/integration-test/spark3/src/test/java/org/apache/uniffle/test/PartitionBlockDataReassignBasicTest.java
index c0fbb3a22..9adc4b315 100644
--- 
a/integration-test/spark3/src/test/java/org/apache/uniffle/test/PartitionBlockDataReassignBasicTest.java
+++ 
b/integration-test/spark3/src/test/java/org/apache/uniffle/test/PartitionBlockDataReassignBasicTest.java
@@ -35,6 +35,7 @@ import org.apache.uniffle.server.ShuffleServerConf;
 import org.apache.uniffle.server.buffer.ShuffleBufferManager;
 import org.apache.uniffle.storage.util.StorageType;
 
+import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_READ_SHUFFLE_HANDLE_CACHE_ENABLED;
 import static 
org.apache.uniffle.client.util.RssClientConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER;
 import static 
org.apache.uniffle.client.util.RssClientConfig.RSS_CLIENT_RETRY_MAX;
 import static 
org.apache.uniffle.common.config.RssClientConf.RSS_CLIENT_REASSIGN_ENABLED;
@@ -99,6 +100,7 @@ public class PartitionBlockDataReassignBasicTest extends 
SparkSQLTest {
         "spark." + RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER,
         String.valueOf(grpcShuffleServers.size()));
     sparkConf.set("spark." + RSS_CLIENT_REASSIGN_ENABLED.key(), "true");
+    sparkConf.set("spark." + RSS_READ_SHUFFLE_HANDLE_CACHE_ENABLED.key(), 
"true");
   }
 
   @Override

Reply via email to