This is an automated email from the ASF dual-hosted git repository.
zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 4805d1335 [#2636] feat(spark): Cache shuffle handle info for reader to
reduce RPC cost when partition reassign is enabled (#2637)
4805d1335 is described below
commit 4805d1335aae06fa699ee3c2f17d94f599f3cb2c
Author: Junfan Zhang <[email protected]>
AuthorDate: Tue Sep 30 14:22:27 2025 +0800
[#2636] feat(spark): Cache shuffle handle info for reader to reduce RPC
cost when partition reassign is enabled (#2637)
### What changes were proposed in this pull request?
This PR introduces a cache mechanism for the read-side shuffle
handle info, to reduce the RPC cost and the driver-side GC pressure when
partition reassign is enabled.
### Why are the changes needed?
for #2636 .
From the cluster Spark jobs, I found that some tasks failed due to RPC
failures when getting the shuffle handle from the driver side with partition
reassign enabled. This is the first step to optimize shuffle info retrieval on
the reader side.
### Does this PR introduce _any_ user-facing change?
Yes.
`rss.client.read.shuffleHandleCacheEnabled=false`
### How was this patch tested?
Existing tests
---
.../org/apache/spark/shuffle/RssSparkConfig.java | 6 +++
.../shuffle/manager/RssShuffleManagerBase.java | 42 +++++++++++++++++++
.../apache/spark/shuffle/RssShuffleManager.java | 27 +++++++-----
.../spark/shuffle/RssShuffleManagerTest.java | 48 ++++++++++++++++++++++
.../test/PartitionBlockDataReassignBasicTest.java | 2 +
5 files changed, 115 insertions(+), 10 deletions(-)
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index 4af6cd421..8dc17b70f 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -39,6 +39,12 @@ import org.apache.uniffle.common.config.RssConf;
public class RssSparkConfig {
+ public static final ConfigOption<Boolean>
RSS_READ_SHUFFLE_HANDLE_CACHE_ENABLED =
+ ConfigOptions.key("rss.client.read.shuffleHandleCacheEnabled")
+ .booleanType()
+ .defaultValue(false)
+ .withDescription("Whether or not to read shuffle handle cache
enabled");
+
public static final ConfigOption<Boolean>
RSS_READ_OVERLAPPING_DECOMPRESSION_ENABLED =
ConfigOptions.key("rss.client.read.overlappingDecompressionEnable")
.booleanType()
diff --git
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
index ad1c4dd9a..be5e56a05 100644
---
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
+++
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
@@ -109,6 +109,7 @@ import org.apache.uniffle.shuffle.ShuffleIdMappingManager;
import static org.apache.spark.launcher.SparkLauncher.EXECUTOR_CORES;
import static
org.apache.spark.shuffle.RssSparkConfig.RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED;
import static
org.apache.spark.shuffle.RssSparkConfig.RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM;
+import static
org.apache.spark.shuffle.RssSparkConfig.RSS_READ_SHUFFLE_HANDLE_CACHE_ENABLED;
import static
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED;
import static
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED;
import static org.apache.spark.shuffle.RssSparkShuffleUtils.isSparkUIEnabled;
@@ -184,6 +185,11 @@ public abstract class RssShuffleManagerBase implements
RssShuffleManagerInterfac
private AtomicBoolean reassignTriggeredOnBlockSendFailure = new
AtomicBoolean(false);
private AtomicBoolean reassignTriggeredOnStageRetry = new
AtomicBoolean(false);
+ // cache to shuffle handle info to reduce the RPC cost when getting the
reader.
+ // this is only valid when the partition reassign is enabled.
+ protected final boolean readShuffleHandleCacheEnabled;
+ private Map<Integer, ShuffleHandleInfo> readShuffleHandleCache =
Maps.newConcurrentMap();
+
private boolean isDriver = false;
public RssShuffleManagerBase(SparkConf conf, boolean isDriver) {
@@ -373,6 +379,8 @@ public abstract class RssShuffleManagerBase implements
RssShuffleManagerInterfac
this.shuffleHandleInfoManager = new ShuffleHandleInfoManager();
this.rssStageResubmitManager = new RssStageResubmitManager();
this.shuffleIdMappingManager = new ShuffleIdMappingManager();
+
+ this.readShuffleHandleCacheEnabled =
rssConf.get(RSS_READ_SHUFFLE_HANDLE_CACHE_ENABLED);
}
@VisibleForTesting
@@ -424,6 +432,7 @@ public abstract class RssShuffleManagerBase implements
RssShuffleManagerInterfac
this.shuffleHandleInfoManager = new ShuffleHandleInfoManager();
this.rssStageResubmitManager = new RssStageResubmitManager();
this.shuffleIdMappingManager = new ShuffleIdMappingManager();
+ this.readShuffleHandleCacheEnabled =
rssConf.get(RSS_READ_SHUFFLE_HANDLE_CACHE_ENABLED);
}
public BlockIdManager getBlockIdManager() {
@@ -444,6 +453,9 @@ public abstract class RssShuffleManagerBase implements
RssShuffleManagerInterfac
if (blockIdManager != null) {
blockIdManager.remove(shuffleId);
}
+ if (readShuffleHandleCache != null) {
+ readShuffleHandleCache.remove(shuffleId);
+ }
if (SparkEnv.get().executorId().equals("driver")) {
shuffleWriteClient.unregisterShuffle(getAppId(), shuffleId);
shuffleIdToPartitionNum.remove(shuffleId);
@@ -1574,4 +1586,34 @@ public abstract class RssShuffleManagerBase implements
RssShuffleManagerInterfac
}
return new CompletableFuture<>();
}
+
+ // only for tests
+ public void clearShuffleHandleCache() {
+ readShuffleHandleCache.clear();
+ }
+
+ public ShuffleHandleInfo getOrFetchShuffleHandle(
+ int shuffleId, Supplier<ShuffleHandleInfo> func) {
+ ShuffleHandleInfo handle =
+ readShuffleHandleCache.computeIfAbsent(
+ shuffleId,
+ integer -> {
+ long start = System.currentTimeMillis();
+ try {
+ return func.get();
+ } catch (Exception e) {
+ LOG.error("Fail to get the shuffle handle for {}", shuffleId,
e);
+ } finally {
+ LOG.info(
+ "Gotten the shuffle handle for shuffle: {} that costs {}
ms",
+ shuffleId,
+ System.currentTimeMillis() - start);
+ }
+ return null;
+ });
+ if (handle == null) {
+ throw new RssException("Shuffle handle id " + shuffleId + " not found");
+ }
+ return handle;
+ }
}
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 58349b56c..e4b180af1 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -23,6 +23,7 @@ import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Supplier;
import java.util.stream.Collectors;
import scala.Tuple2;
@@ -365,16 +366,22 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
final int partitionNum =
rssShuffleHandle.getDependency().partitioner().numPartitions();
int shuffleId = rssShuffleHandle.getShuffleId();
ShuffleHandleInfo shuffleHandleInfo;
- if (shuffleManagerRpcServiceEnabled &&
rssStageRetryForWriteFailureEnabled) {
- // In Stage Retry mode, Get the ShuffleServer list from the Driver based
on the shuffleId.
- shuffleHandleInfo =
- getRemoteShuffleHandleInfoWithStageRetry(
- context.stageId(), context.stageAttemptNumber(), shuffleId,
false);
- } else if (shuffleManagerRpcServiceEnabled && partitionReassignEnabled) {
- // In Block Retry mode, Get the ShuffleServer list from the Driver based
on the shuffleId.
- shuffleHandleInfo =
- getRemoteShuffleHandleInfoWithBlockRetry(
- context.stageId(), context.stageAttemptNumber(), shuffleId,
false);
+
+ if (shuffleManagerRpcServiceEnabled
+ && (rssStageRetryForWriteFailureEnabled || partitionReassignEnabled)) {
+ Supplier<ShuffleHandleInfo> func =
+ rssStageRetryForWriteFailureEnabled
+ ? () ->
+ getRemoteShuffleHandleInfoWithStageRetry(
+ context.stageId(), context.stageAttemptNumber(),
shuffleId, false)
+ : () ->
+ getRemoteShuffleHandleInfoWithBlockRetry(
+ context.stageId(), context.stageAttemptNumber(),
shuffleId, false);
+ if (readShuffleHandleCacheEnabled) {
+ shuffleHandleInfo = super.getOrFetchShuffleHandle(shuffleId, func);
+ } else {
+ shuffleHandleInfo = func.get();
+ }
} else {
shuffleHandleInfo =
new SimpleShuffleHandleInfo(
diff --git
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
index 877bf6e62..df6255c5f 100644
---
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
+++
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
@@ -17,14 +17,22 @@
package org.apache.spark.shuffle;
+import java.util.Collections;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Supplier;
+
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
+import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
+import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo;
import org.apache.spark.sql.internal.SQLConf;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
+import org.apache.uniffle.client.impl.FailedBlockSendTracker;
import org.apache.uniffle.client.util.RssClientConfig;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleDataDistributionType;
@@ -33,6 +41,7 @@ import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.BlockIdLayout;
+import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.storage.util.StorageType;
import static
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED;
@@ -42,6 +51,7 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrowsExactly;
import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@@ -252,4 +262,42 @@ public class RssShuffleManagerTest extends
RssShuffleManagerTestBase {
conf.set("spark.driver.host", "localhost");
return conf;
}
+
+ @Test
+ public void testReadCacheShuffleInfo() {
+ SparkConf conf = new SparkConf();
+ conf.setAppName("testApp")
+ .setMaster("local[2]")
+ .set(RssSparkConfig.RSS_TEST_FLAG.key(), "true")
+ .set(RssSparkConfig.RSS_TEST_MODE_ENABLE.key(), "true")
+ .set(RssSparkConfig.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS.key(), "10000")
+ .set(RssSparkConfig.RSS_CLIENT_RETRY_MAX.key(), "10")
+ .set(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS.key(), "1000")
+ .set(RssSparkConfig.RSS_STORAGE_TYPE.key(),
StorageType.LOCALFILE.name())
+ .set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(),
"127.0.0.1:12345,127.0.0.1:12346");
+ Map<String, Set<Long>> successBlocks = JavaUtils.newConcurrentMap();
+ Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker =
JavaUtils.newConcurrentMap();
+ RssShuffleManager manager =
+ TestUtils.createShuffleManager(
+ conf, false, null, successBlocks, taskToFailedBlockSendTracker);
+
+ // case1: legal fetch and cache
+ Supplier<ShuffleHandleInfo> func1 =
+ () ->
+ new SimpleShuffleHandleInfo(
+ 1, Collections.emptyMap(),
RemoteStorageInfo.EMPTY_REMOTE_STORAGE);
+ ShuffleHandleInfo handle1 = manager.getOrFetchShuffleHandle(1, func1);
+ ShuffleHandleInfo handle2 = manager.getOrFetchShuffleHandle(1, func1);
+ assertEquals(handle1, handle2);
+
+ // case2: illegal fetch
+ manager.clearShuffleHandleCache();
+ Supplier<ShuffleHandleInfo> func2 = () -> null;
+ try {
+ ShuffleHandleInfo handle3 = manager.getOrFetchShuffleHandle(1, func2);
+ fail();
+ } catch (Exception e) {
+ // ignore
+ }
+ }
}
diff --git
a/integration-test/spark3/src/test/java/org/apache/uniffle/test/PartitionBlockDataReassignBasicTest.java
b/integration-test/spark3/src/test/java/org/apache/uniffle/test/PartitionBlockDataReassignBasicTest.java
index c0fbb3a22..9adc4b315 100644
---
a/integration-test/spark3/src/test/java/org/apache/uniffle/test/PartitionBlockDataReassignBasicTest.java
+++
b/integration-test/spark3/src/test/java/org/apache/uniffle/test/PartitionBlockDataReassignBasicTest.java
@@ -35,6 +35,7 @@ import org.apache.uniffle.server.ShuffleServerConf;
import org.apache.uniffle.server.buffer.ShuffleBufferManager;
import org.apache.uniffle.storage.util.StorageType;
+import static
org.apache.spark.shuffle.RssSparkConfig.RSS_READ_SHUFFLE_HANDLE_CACHE_ENABLED;
import static
org.apache.uniffle.client.util.RssClientConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER;
import static
org.apache.uniffle.client.util.RssClientConfig.RSS_CLIENT_RETRY_MAX;
import static
org.apache.uniffle.common.config.RssClientConf.RSS_CLIENT_REASSIGN_ENABLED;
@@ -99,6 +100,7 @@ public class PartitionBlockDataReassignBasicTest extends
SparkSQLTest {
"spark." + RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER,
String.valueOf(grpcShuffleServers.size()));
sparkConf.set("spark." + RSS_CLIENT_REASSIGN_ENABLED.key(), "true");
+ sparkConf.set("spark." + RSS_READ_SHUFFLE_HANDLE_CACHE_ENABLED.key(),
"true");
}
@Override