This is an automated email from the ASF dual-hosted git repository.
zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 496982261 [#1787] feat(spark): Fine grained stage retry switch for
fetch/write failure (#1788)
496982261 is described below
commit 496982261affd9a4dbcad168fa50470471ac9742
Author: Junfan Zhang <[email protected]>
AuthorDate: Fri Jun 21 09:34:50 2024 +0800
[#1787] feat(spark): Fine grained stage retry switch for fetch/write
failure (#1788)
### What changes were proposed in this pull request?
Introducing the independent fine grained switch for fetch/write failure on
stage retry mechanism
### Why are the changes needed?
1. The stage retry of write failure is supported recently, which is not
stable. Independent switch will benefit users using the fetch failure stage
retry which won't bring too much risk.
2. Leveraging the partition reassign mechanism, the importance of write
failure stage retry is decreasing.
### Does this PR introduce _any_ user-facing change?
Yes. The detailed doc will be added in the next following stage retry
improvement PRs.
### How was this patch tested?
Existing tests.
---
.../org/apache/spark/shuffle/RssSparkConfig.java | 30 ++++++++---
.../apache/spark/shuffle/RssSparkShuffleUtils.java | 4 +-
.../shuffle/manager/RssShuffleManagerBase.java | 17 ++++--
.../apache/spark/shuffle/RssShuffleManager.java | 35 ++++++++++---
.../spark/shuffle/reader/RssShuffleReader.java | 4 +-
.../spark/shuffle/writer/RssShuffleWriter.java | 4 +-
.../apache/spark/shuffle/RssShuffleManager.java | 39 ++++++++++----
.../spark/shuffle/reader/RssShuffleReader.java | 3 +-
.../spark/shuffle/writer/RssShuffleWriter.java | 3 +-
.../spark/shuffle/RssShuffleManagerTest.java | 60 ++++++++++++++++++++++
.../apache/uniffle/common/config/ConfigOption.java | 2 +-
11 files changed, 166 insertions(+), 35 deletions(-)
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index 455ac6d52..b47707d16 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -37,6 +37,29 @@ import org.apache.uniffle.common.config.RssConf;
public class RssSparkConfig {
+ public static final ConfigOption<Boolean> RSS_RESUBMIT_STAGE_ENABLED =
+ ConfigOptions.key("rss.stageRetry.enabled")
+ .booleanType()
+ .defaultValue(false)
+ .withDeprecatedKeys(RssClientConfig.RSS_RESUBMIT_STAGE)
+ .withDescription("Whether to enable the resubmit stage for
fetch/write failure");
+
+ public static final ConfigOption<Boolean>
RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED =
+ ConfigOptions.key("rss.stageRetry.fetchFailureEnabled")
+ .booleanType()
+ .defaultValue(false)
+ .withFallbackKeys(RSS_RESUBMIT_STAGE_ENABLED.key(),
RssClientConfig.RSS_RESUBMIT_STAGE)
+ .withDescription(
+ "If set to true, the stage retry mechanism will be enabled when
a fetch failure occurs.");
+
+ public static final ConfigOption<Boolean>
RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED =
+ ConfigOptions.key("rss.stageRetry.writeFailureEnabled")
+ .booleanType()
+ .defaultValue(false)
+ .withFallbackKeys(RSS_RESUBMIT_STAGE_ENABLED.key(),
RssClientConfig.RSS_RESUBMIT_STAGE)
+ .withDescription(
+ "If set to true, the stage retry mechanism will be enabled when
a write failure occurs.");
+
public static final ConfigOption<Boolean>
RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED =
ConfigOptions.key("rss.blockId.selfManagementEnabled")
.booleanType()
@@ -404,13 +427,6 @@ public class RssSparkConfig {
.doc(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT.description()))
.createWithDefault(-1);
- public static final ConfigEntry<Boolean> RSS_RESUBMIT_STAGE =
- createBooleanBuilder(
- new ConfigBuilder(SPARK_RSS_CONFIG_PREFIX +
RssClientConfig.RSS_RESUBMIT_STAGE)
- .internal()
- .doc("Whether to enable the resubmit stage."))
- .createWithDefault(false);
-
public static final ConfigEntry<Integer> RSS_MAX_PARTITIONS =
createIntegerBuilder(
new ConfigBuilder("spark.rss.blockId.maxPartitions")
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
index 51384f180..b3763df32 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
@@ -47,7 +47,6 @@ import
org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
import org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest;
import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
import org.apache.uniffle.client.util.ClientUtils;
-import org.apache.uniffle.client.util.RssClientConfig;
import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
@@ -57,6 +56,7 @@ import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.exception.RssFetchFailedException;
import org.apache.uniffle.common.util.Constants;
+import static
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED;
import static org.apache.uniffle.common.util.Constants.DRIVER_HOST;
public class RssSparkShuffleUtils {
@@ -353,7 +353,7 @@ public class RssSparkShuffleUtils {
int stageAttemptId,
Set<Integer> failedPartitions) {
RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
- if (rssConf.getBoolean(RssClientConfig.RSS_RESUBMIT_STAGE, false)
+ if (rssConf.getBoolean(RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED)
&& RssSparkShuffleUtils.isStageResubmitSupported()) {
String driver = rssConf.getString(DRIVER_HOST, "");
int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
diff --git
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
index 25acf1f0f..209ede25c 100644
---
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
+++
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
@@ -103,8 +103,9 @@ public abstract class RssShuffleManagerBase implements
RssShuffleManagerInterfac
protected SparkConf sparkConf;
protected ShuffleManagerClient shuffleManagerClient;
- /** Whether to enable the dynamic shuffleServer function rewrite and reread
functions */
- protected boolean rssResubmitStage;
+ protected boolean rssStageRetryEnabled;
+ protected boolean rssStageRetryForWriteFailureEnabled;
+ protected boolean rssStageRetryForFetchFailureEnabled;
/**
* Mapping between ShuffleId and ShuffleServer list. ShuffleServer list is
dynamically allocated.
* ShuffleServer is not obtained from RssShuffleHandle, but from this
mapping.
@@ -1046,7 +1047,15 @@ public abstract class RssShuffleManagerBase implements
RssShuffleManagerInterfac
appId, defaultRemoteStorage, dynamicConfEnabled, storageType,
shuffleWriteClient);
}
- public boolean isRssResubmitStage() {
- return rssResubmitStage;
+ public boolean isRssStageRetryEnabled() {
+ return rssStageRetryEnabled;
+ }
+
+ public boolean isRssStageRetryForWriteFailureEnabled() {
+ return rssStageRetryForWriteFailureEnabled;
+ }
+
+ public boolean isRssStageRetryForFetchFailureEnabled() {
+ return rssStageRetryForFetchFailureEnabled;
}
}
diff --git
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 974b99862..de5d4da63 100644
---
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -18,6 +18,7 @@
package org.apache.spark.shuffle;
import java.io.IOException;
+import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@@ -33,6 +34,7 @@ import scala.collection.Seq;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Sets;
+import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
@@ -56,7 +58,6 @@ import org.slf4j.LoggerFactory;
import org.apache.uniffle.client.impl.FailedBlockSendTracker;
import org.apache.uniffle.client.util.ClientUtils;
-import org.apache.uniffle.client.util.RssClientConfig;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssClientConf;
@@ -74,6 +75,8 @@ import
org.apache.uniffle.shuffle.manager.ShuffleManagerGrpcService;
import org.apache.uniffle.shuffle.manager.ShuffleManagerServerFactory;
import static
org.apache.spark.shuffle.RssSparkConfig.RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED;
+import static
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED;
+import static
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED;
import static org.apache.uniffle.common.config.RssBaseConf.RPC_SERVER_PORT;
import static
org.apache.uniffle.common.config.RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE;
@@ -165,13 +168,29 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
// shuffle cluster, we don't need shuffle data locality
sparkConf.set("spark.shuffle.reduceLocality.enabled", "false");
LOG.info("Disable shuffle data locality in RssShuffleManager.");
- this.rssResubmitStage =
- rssConf.getBoolean(RssClientConfig.RSS_RESUBMIT_STAGE, false)
- && RssSparkShuffleUtils.isStageResubmitSupported();
+
+ // stage retry for write/fetch failure
+ rssStageRetryForFetchFailureEnabled =
+ rssConf.get(RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED);
+ rssStageRetryForWriteFailureEnabled =
+ rssConf.get(RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED);
+ if (rssStageRetryForFetchFailureEnabled ||
rssStageRetryForWriteFailureEnabled) {
+ rssStageRetryEnabled = true;
+ List<String> logTips = new ArrayList<>();
+      if (rssStageRetryForWriteFailureEnabled) {
+        logTips.add("write");
+      }
+      if (rssStageRetryForFetchFailureEnabled) {
+        logTips.add("fetch");
+      }
+ LOG.info(
+ "Activate the stage retry mechanism that will resubmit stage on {}
failure",
+ StringUtils.join(logTips, "/"));
+ }
this.partitionReassignEnabled =
rssConf.getBoolean(RssClientConf.RSS_CLIENT_REASSIGN_ENABLED);
this.blockIdSelfManagedEnabled =
rssConf.getBoolean(RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED);
this.shuffleManagerRpcServiceEnabled =
- partitionReassignEnabled || rssResubmitStage ||
blockIdSelfManagedEnabled;
+ partitionReassignEnabled || rssStageRetryEnabled ||
blockIdSelfManagedEnabled;
if (!sparkConf.getBoolean(RssSparkConfig.RSS_TEST_FLAG.key(), false)) {
if (isDriver) {
heartBeatScheduledExecutorService =
@@ -334,7 +353,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
shuffleIdToPartitionNum.putIfAbsent(shuffleId,
dependency.partitioner().numPartitions());
shuffleIdToNumMapTasks.putIfAbsent(shuffleId,
dependency.rdd().partitions().length);
- if (shuffleManagerRpcServiceEnabled && rssResubmitStage) {
+ if (shuffleManagerRpcServiceEnabled && rssStageRetryEnabled) {
ShuffleHandleInfo handleInfo =
new MutableShuffleHandleInfo(shuffleId, partitionToServers,
remoteStorage);
StageAttemptShuffleHandleInfo stageAttemptShuffleHandleInfo =
@@ -406,7 +425,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
int shuffleId = rssHandle.getShuffleId();
String taskId = "" + context.taskAttemptId() + "_" +
context.attemptNumber();
ShuffleHandleInfo shuffleHandleInfo;
- if (shuffleManagerRpcServiceEnabled && rssResubmitStage) {
+ if (shuffleManagerRpcServiceEnabled && rssStageRetryEnabled) {
// In Stage Retry mode, Get the ShuffleServer list from the Driver
based on the shuffleId
shuffleHandleInfo =
getRemoteShuffleHandleInfoWithStageRetry(shuffleId);
} else if (shuffleManagerRpcServiceEnabled && partitionReassignEnabled) {
@@ -479,7 +498,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
+ "]");
start = System.currentTimeMillis();
ShuffleHandleInfo shuffleHandleInfo;
- if (shuffleManagerRpcServiceEnabled && rssResubmitStage) {
+ if (shuffleManagerRpcServiceEnabled && rssStageRetryEnabled) {
// In Stage Retry mode, Get the ShuffleServer list from the Driver
based on the shuffleId.
shuffleHandleInfo =
getRemoteShuffleHandleInfoWithStageRetry(shuffleId);
} else if (shuffleManagerRpcServiceEnabled && partitionReassignEnabled) {
diff --git
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index 8f5118e68..3bf5840e8 100644
---
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -54,6 +54,8 @@ import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
+import static
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED;
+
public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
private static final Logger LOG =
LoggerFactory.getLogger(RssShuffleReader.class);
@@ -231,7 +233,7 @@ public class RssShuffleReader<K, C> implements
ShuffleReader<K, C> {
}
// stage re-compute and shuffle manager server port are both set
- if (rssConf.getBoolean(RssClientConfig.RSS_RESUBMIT_STAGE, false)
+ if (rssConf.getBoolean(RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED)
&& rssConf.getInteger(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT, 0) > 0)
{
String driver = rssConf.getString("driver.host", "");
int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
diff --git
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 2689ee39c..ee552d325 100644
---
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -80,6 +80,8 @@ import
org.apache.uniffle.common.exception.RssSendFailedException;
import org.apache.uniffle.common.exception.RssWaitFailedException;
import org.apache.uniffle.storage.util.StorageType;
+import static
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED;
+
public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
private static final Logger LOG =
LoggerFactory.getLogger(RssShuffleWriter.class);
@@ -238,7 +240,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
writeImpl(records);
} catch (Exception e) {
taskFailureCallback.apply(taskId);
- if (shuffleManager.isRssResubmitStage()) {
+ if
(RssSparkConfig.toRssConf(sparkConf).get(RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED))
{
throwFetchFailedIfNecessary(e);
} else {
throw e;
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 9e6f2a26f..1d5050790 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -18,6 +18,7 @@
package org.apache.spark.shuffle;
import java.io.IOException;
+import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@@ -37,6 +38,7 @@ import scala.collection.Seq;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Sets;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
+import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.spark.MapOutputTracker;
import org.apache.spark.ShuffleDependency;
@@ -64,7 +66,6 @@ import org.slf4j.LoggerFactory;
import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking;
import org.apache.uniffle.client.impl.FailedBlockSendTracker;
import org.apache.uniffle.client.util.ClientUtils;
-import org.apache.uniffle.client.util.RssClientConfig;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleServerInfo;
@@ -84,6 +85,8 @@ import
org.apache.uniffle.shuffle.manager.ShuffleManagerServerFactory;
import static
org.apache.spark.shuffle.RssSparkConfig.RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED;
import static
org.apache.spark.shuffle.RssSparkConfig.RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM;
+import static
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED;
+import static
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED;
import static org.apache.uniffle.common.config.RssBaseConf.RPC_SERVER_PORT;
import static
org.apache.uniffle.common.config.RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE;
@@ -179,10 +182,28 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
LOG.info("Disable shuffle data locality in RssShuffleManager.");
taskToSuccessBlockIds = JavaUtils.newConcurrentMap();
taskToFailedBlockSendTracker = JavaUtils.newConcurrentMap();
- this.rssResubmitStage =
- rssConf.getBoolean(RssClientConfig.RSS_RESUBMIT_STAGE, false)
- && RssSparkShuffleUtils.isStageResubmitSupported();
- this.partitionReassignEnabled =
rssConf.getBoolean(RssClientConf.RSS_CLIENT_REASSIGN_ENABLED);
+
+    this.partitionReassignEnabled =
rssConf.get(RssClientConf.RSS_CLIENT_REASSIGN_ENABLED);
+
+ // stage retry for write/fetch failure
+ rssStageRetryForFetchFailureEnabled =
+ rssConf.get(RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED);
+ rssStageRetryForWriteFailureEnabled =
+ rssConf.get(RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED);
+ if (rssStageRetryForFetchFailureEnabled ||
rssStageRetryForWriteFailureEnabled) {
+ rssStageRetryEnabled = true;
+ List<String> logTips = new ArrayList<>();
+      if (rssStageRetryForWriteFailureEnabled) {
+        logTips.add("write");
+      }
+      if (rssStageRetryForFetchFailureEnabled) {
+        logTips.add("fetch");
+      }
+ LOG.info(
+ "Activate the stage retry mechanism that will resubmit stage on {}
failure",
+ StringUtils.join(logTips, "/"));
+ }
// The feature of partition reassign is exclusive with multiple replicas
and stage retry.
if (partitionReassignEnabled) {
@@ -194,7 +215,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
this.blockIdSelfManagedEnabled =
rssConf.getBoolean(RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED);
this.shuffleManagerRpcServiceEnabled =
- partitionReassignEnabled || rssResubmitStage ||
blockIdSelfManagedEnabled;
+ partitionReassignEnabled || rssStageRetryEnabled ||
blockIdSelfManagedEnabled;
if (isDriver) {
heartBeatScheduledExecutorService =
ThreadUtils.getDaemonSingleThreadScheduledExecutor("rss-heartbeat");
@@ -448,7 +469,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
startHeartbeat();
shuffleIdToPartitionNum.putIfAbsent(shuffleId,
dependency.partitioner().numPartitions());
shuffleIdToNumMapTasks.putIfAbsent(shuffleId,
dependency.rdd().partitions().length);
- if (shuffleManagerRpcServiceEnabled && rssResubmitStage) {
+ if (shuffleManagerRpcServiceEnabled && rssStageRetryEnabled) {
ShuffleHandleInfo shuffleHandleInfo =
new MutableShuffleHandleInfo(shuffleId, partitionToServers,
remoteStorage);
StageAttemptShuffleHandleInfo handleInfo =
@@ -492,7 +513,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
writeMetrics = context.taskMetrics().shuffleWriteMetrics();
}
ShuffleHandleInfo shuffleHandleInfo;
- if (shuffleManagerRpcServiceEnabled && rssResubmitStage) {
+ if (shuffleManagerRpcServiceEnabled && rssStageRetryEnabled) {
// In Stage Retry mode, Get the ShuffleServer list from the Driver based
on the shuffleId.
shuffleHandleInfo = getRemoteShuffleHandleInfoWithStageRetry(shuffleId);
} else if (shuffleManagerRpcServiceEnabled && partitionReassignEnabled) {
@@ -636,7 +657,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
final int partitionNum =
rssShuffleHandle.getDependency().partitioner().numPartitions();
int shuffleId = rssShuffleHandle.getShuffleId();
ShuffleHandleInfo shuffleHandleInfo;
- if (shuffleManagerRpcServiceEnabled && rssResubmitStage) {
+ if (shuffleManagerRpcServiceEnabled && rssStageRetryEnabled) {
// In Stage Retry mode, Get the ShuffleServer list from the Driver based
on the shuffleId.
shuffleHandleInfo = getRemoteShuffleHandleInfoWithStageRetry(shuffleId);
} else if (shuffleManagerRpcServiceEnabled && partitionReassignEnabled) {
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index 9b176340b..d76abc4f7 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -56,6 +56,7 @@ import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
+import static
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED;
import static org.apache.uniffle.common.util.Constants.DRIVER_HOST;
public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
@@ -189,7 +190,7 @@ public class RssShuffleReader<K, C> implements
ShuffleReader<K, C> {
resultIter = new InterruptibleIterator<>(context, resultIter);
}
// resubmit stage and shuffle manager server port are both set
- if (rssConf.getBoolean(RssClientConfig.RSS_RESUBMIT_STAGE, false)
+ if (rssConf.getBoolean(RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED)
&& rssConf.getInteger(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT, 0) > 0)
{
String driver = rssConf.getString(DRIVER_HOST, "");
int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 37948b6b8..1d283cd94 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -92,6 +92,7 @@ import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.storage.util.StorageType;
import static
org.apache.spark.shuffle.RssSparkConfig.RSS_PARTITION_REASSIGN_BLOCK_RETRY_MAX_TIMES;
+import static
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED;
public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
@@ -277,7 +278,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
writeImpl(records);
} catch (Exception e) {
taskFailureCallback.apply(taskId);
- if (shuffleManager.isRssResubmitStage()) {
+ if
(RssSparkConfig.toRssConf(sparkConf).get(RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED))
{
throwFetchFailedIfNecessary(e);
} else {
throw e;
diff --git
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
index 66b2c9a44..2a92b6ed5 100644
---
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
+++
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
@@ -27,14 +27,17 @@ import org.junit.jupiter.params.provider.ValueSource;
import org.apache.uniffle.client.util.RssClientConfig;
import org.apache.uniffle.common.ShuffleDataDistributionType;
+import org.apache.uniffle.common.config.ConfigOption;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.BlockIdLayout;
import org.apache.uniffle.storage.util.StorageType;
+import static
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED;
import static
org.apache.spark.shuffle.RssSparkConfig.RSS_SHUFFLE_MANAGER_GRPC_PORT;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrowsExactly;
import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -108,6 +111,8 @@ public class RssShuffleManagerTest extends
RssShuffleManagerTestBase {
RssShuffleManager shuffleManager = new RssShuffleManager(conf, true);
+ ConfigOption<Boolean> a = RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED;
+
assertTrue(conf.get(RSS_SHUFFLE_MANAGER_GRPC_PORT) > 0);
}
@@ -171,4 +176,59 @@ public class RssShuffleManagerTest extends
RssShuffleManagerTestBase {
+ " partitions.",
e.getMessage());
}
+
+ @Test
+ public void testWithStageRetry() {
+ // case1: disable the stage retry
+ SparkConf conf = createSparkConf();
+ RssShuffleManager shuffleManager = new RssShuffleManager(conf, true);
+ assertFalse(shuffleManager.isRssStageRetryEnabled());
+ assertFalse(shuffleManager.isRssStageRetryForFetchFailureEnabled());
+ assertFalse(shuffleManager.isRssStageRetryForWriteFailureEnabled());
+ shuffleManager.stop();
+
+ // case2: enable the stage retry
+ conf.set(
+ RssSparkConfig.SPARK_RSS_CONFIG_PREFIX +
RssSparkConfig.RSS_RESUBMIT_STAGE_ENABLED.key(),
+ "true");
+ shuffleManager = new RssShuffleManager(conf, true);
+ assertTrue(shuffleManager.isRssStageRetryEnabled());
+ assertTrue(shuffleManager.isRssStageRetryForFetchFailureEnabled());
+ assertTrue(shuffleManager.isRssStageRetryForWriteFailureEnabled());
+ shuffleManager.stop();
+
+ // case3: overwrite the stage retry
+ conf.set(
+ RssSparkConfig.SPARK_RSS_CONFIG_PREFIX
+ + RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED.key(),
+ "false");
+ shuffleManager = new RssShuffleManager(conf, true);
+ assertTrue(shuffleManager.isRssStageRetryEnabled());
+ assertFalse(shuffleManager.isRssStageRetryForFetchFailureEnabled());
+ assertTrue(shuffleManager.isRssStageRetryForWriteFailureEnabled());
+ shuffleManager.stop();
+
+ // case4: enable the partial stage retry of fetch failure
+ conf = createSparkConf();
+ conf.set(
+ RssSparkConfig.SPARK_RSS_CONFIG_PREFIX
+ + RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED.key(),
+ "true");
+ shuffleManager = new RssShuffleManager(conf, true);
+ assertTrue(shuffleManager.isRssStageRetryEnabled());
+ assertTrue(shuffleManager.isRssStageRetryForFetchFailureEnabled());
+ assertFalse(shuffleManager.isRssStageRetryForWriteFailureEnabled());
+ shuffleManager.stop();
+ }
+
+ private SparkConf createSparkConf() {
+ SparkConf conf = new SparkConf();
+ conf.set(RssSparkConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED.key(), "false");
+ conf.set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), "m1:8001,m2:8002");
+ conf.set("spark.rss.storage.type", StorageType.LOCALFILE.name());
+ conf.set(RssSparkConfig.RSS_TEST_MODE_ENABLE, true);
+ conf.set("spark.task.maxFailures", "4");
+ conf.set("spark.driver.host", "localhost");
+ return conf;
+ }
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/config/ConfigOption.java
b/common/src/main/java/org/apache/uniffle/common/config/ConfigOption.java
index a507eb110..7c7319279 100644
--- a/common/src/main/java/org/apache/uniffle/common/config/ConfigOption.java
+++ b/common/src/main/java/org/apache/uniffle/common/config/ConfigOption.java
@@ -86,7 +86,7 @@ public class ConfigOption<T> {
* @return A new config option, with given description.
*/
public ConfigOption<T> withDescription(final String description) {
- return new ConfigOption<>(key, clazz, description, defaultValue,
converter);
+ return new ConfigOption<>(key, clazz, description, defaultValue,
converter, fallbackKeys);
}
/**