This is an automated email from the ASF dual-hosted git repository.
fchen pushed a commit to branch branch-0.4
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/branch-0.4 by this push:
new f7760359a [CELEBORN-1496][0.4] Differentiate map results with only
different stageAttemptId
f7760359a is described below
commit f7760359a5d92f028852a3435b587ab6a6e2f529
Author: jiang13021 <[email protected]>
AuthorDate: Mon Sep 2 13:45:57 2024 +0800
[CELEBORN-1496][0.4] Differentiate map results with only different
stageAttemptId
backport https://github.com/apache/celeborn/pull/2609 to branch-0.4
### What changes were proposed in this pull request?
Let attemptNumber = (stageAttemptId << 16) | taskAttemptNumber, to
differentiate map results with only different stageAttemptId.
### Why are the changes needed?
If we can't differentiate map tasks with only different stageAttemptId, it
may lead to mixed reading of two map tasks' shuffle write batches during
shuffle read, causing data correctness issue.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Add ut:
org.apache.spark.shuffle.celeborn.SparkShuffleManagerSuite#testWrongSparkConfMaxAttemptLimit
Closes #2609 from jiang13021/spark_stage_attempt_id.
Lead-authored-by: jiang13021 <jiangyanze.jyz@antgroup.com>
Closes #2717 from cfmcgrady/CELEBORN-1496-branch-0.4.
Authored-by: jiang13021 <[email protected]>
Signed-off-by: Fu Chen <[email protected]>
---
.../spark/shuffle/celeborn/SparkCommonUtils.java | 53 ++++++++++++++++++++++
.../shuffle/celeborn/HashBasedShuffleWriter.java | 16 ++++---
.../shuffle/celeborn/SortBasedShuffleWriter.java | 12 +++--
.../shuffle/celeborn/SparkShuffleManager.java | 1 +
.../shuffle/celeborn/CelebornShuffleReader.scala | 3 +-
.../celeborn/CelebornShuffleManagerSuite.scala | 30 ++++++++++++
.../CelebornColumnarShuffleReaderSuite.scala | 12 +++--
.../shuffle/celeborn/HashBasedShuffleWriter.java | 16 ++++---
.../shuffle/celeborn/SortBasedShuffleWriter.java | 12 +++--
.../shuffle/celeborn/SparkShuffleManager.java | 1 +
.../shuffle/celeborn/CelebornShuffleReader.scala | 3 +-
.../celeborn/CelebornShuffleManagerSuite.scala | 30 ++++++++++++
12 files changed, 160 insertions(+), 29 deletions(-)
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java
new file mode 100644
index 000000000..f00fe063e
--- /dev/null
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.scheduler.DAGScheduler;
+
+public class SparkCommonUtils {
+ public static void validateAttemptConfig(SparkConf conf) throws
IllegalArgumentException {
+ int maxStageAttempts =
+ conf.getInt(
+ "spark.stage.maxConsecutiveAttempts",
+ DAGScheduler.DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS());
+ // In Spark 2, the parameter is referred to as MAX_TASK_FAILURES, while in
Spark 3, it has been
+ // changed to TASK_MAX_FAILURES. The default value for both is
consistently set to 4.
+ int maxTaskAttempts = conf.getInt("spark.task.maxFailures", 4);
+ if (maxStageAttempts >= (1 << 15) || maxTaskAttempts >= (1 << 16)) {
+ // The map attemptId is a non-negative number constructed from
+ // both stageAttemptNumber and taskAttemptNumber.
+ // The high 16 bits of the map attemptId are used for the
stageAttemptNumber,
+ // and the low 16 bits are used for the taskAttemptNumber.
+ // So spark.stage.maxConsecutiveAttempts should be less than 32768 (1 <<
15)
+ // and spark.task.maxFailures should be less than 65536 (1 << 16).
+ throw new IllegalArgumentException(
+ "The spark.stage.maxConsecutiveAttempts should be less than 32768
(currently "
+ + maxStageAttempts
+ + ")"
+ + "and spark.task.maxFailures should be less than 65536
(currently "
+ + maxTaskAttempts
+ + ").");
+ }
+ }
+
+ public static int getEncodedAttemptNumber(TaskContext context) {
+ return (context.stageAttemptNumber() << 16) | context.attemptNumber();
+ }
+}
diff --git
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
index 407b32849..06d7ccc72 100644
---
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
+++
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
@@ -68,6 +68,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
private final ShuffleWriteMetrics writeMetrics;
private final int shuffleId;
private final int mapId;
+ private final int encodedAttemptId;
private final TaskContext taskContext;
private final ShuffleClient shuffleClient;
private final int numMappers;
@@ -112,6 +113,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
this.mapId = mapId;
this.dep = handle.dependency();
this.shuffleId = shuffleId;
+ this.encodedAttemptId =
SparkCommonUtils.getEncodedAttemptNumber(taskContext);
SerializerInstance serializer = dep.serializer().newInstance();
this.partitioner = dep.partitioner();
this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics();
@@ -146,7 +148,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
new DataPusher(
shuffleId,
mapId,
- taskContext.attemptNumber(),
+ encodedAttemptId,
taskContext.taskAttemptId(),
numMappers,
numPartitions,
@@ -273,7 +275,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
shuffleClient.pushData(
shuffleId,
mapId,
- taskContext.attemptNumber(),
+ encodedAttemptId,
partitionId,
buffer,
0,
@@ -318,7 +320,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
// here we wait for all the in-flight batches to return which sent by
dataPusher thread
dataPusher.waitOnTermination();
sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());
- shuffleClient.prepareForMergeData(shuffleId, mapId,
taskContext.attemptNumber());
+ shuffleClient.prepareForMergeData(shuffleId, mapId, encodedAttemptId);
// merge and push residual data to reduce network traffic
// NB: since dataPusher thread have no in-flight data at this point,
@@ -330,7 +332,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
shuffleClient.mergeData(
shuffleId,
mapId,
- taskContext.attemptNumber(),
+ encodedAttemptId,
i,
sendBuffers[i],
0,
@@ -343,7 +345,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
writeMetrics.incBytesWritten(bytesWritten);
}
}
- shuffleClient.pushMergedData(shuffleId, mapId,
taskContext.attemptNumber());
+ shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);
updateMapStatus();
@@ -352,7 +354,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
sendOffsets = null;
long waitStartTime = System.nanoTime();
- shuffleClient.mapperEnd(shuffleId, mapId, taskContext.attemptNumber(),
numMappers);
+ shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
BlockManagerId bmId = SparkEnv.get().blockManager().shuffleServerId();
@@ -389,7 +391,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
}
}
} finally {
- shuffleClient.cleanup(shuffleId, mapId, taskContext.attemptNumber());
+ shuffleClient.cleanup(shuffleId, mapId, encodedAttemptId);
}
}
}
diff --git
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
index 0985c9b69..a8bd23c21 100644
---
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
+++
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
@@ -62,6 +62,7 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
private final ShuffleWriteMetrics writeMetrics;
private final int shuffleId;
private final int mapId;
+ private final int encodedAttemptId;
private final TaskContext taskContext;
private final ShuffleClient shuffleClient;
private final int numMappers;
@@ -102,6 +103,7 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
this.mapId = taskContext.partitionId();
this.dep = dep;
this.shuffleId = shuffleId;
+ this.encodedAttemptId =
SparkCommonUtils.getEncodedAttemptNumber(taskContext);
SerializerInstance serializer = dep.serializer().newInstance();
this.partitioner = dep.partitioner();
this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics();
@@ -130,7 +132,7 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
taskContext,
shuffleId,
mapId,
- taskContext.attemptNumber(),
+ encodedAttemptId,
taskContext.taskAttemptId(),
numMappers,
numPartitions,
@@ -277,7 +279,7 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
shuffleClient.pushData(
shuffleId,
mapId,
- taskContext.attemptNumber(),
+ encodedAttemptId,
partitionId,
buffer,
0,
@@ -295,12 +297,12 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
pusher.close();
writeMetrics.incWriteTime(System.nanoTime() - pushStartTime);
- shuffleClient.pushMergedData(shuffleId, mapId,
taskContext.attemptNumber());
+ shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);
updateMapStatus();
long waitStartTime = System.nanoTime();
- shuffleClient.mapperEnd(shuffleId, mapId, taskContext.attemptNumber(),
numMappers);
+ shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
}
@@ -336,7 +338,7 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
} catch (IOException e) {
return Option.apply(null);
} finally {
- shuffleClient.cleanup(shuffleId, mapId, taskContext.attemptNumber());
+ shuffleClient.cleanup(shuffleId, mapId, encodedAttemptId);
}
}
}
diff --git
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index 157fc2d0b..cabecfb15 100644
---
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -64,6 +64,7 @@ public class SparkShuffleManager implements ShuffleManager {
private ExecutorShuffleIdTracker shuffleIdTracker = new
ExecutorShuffleIdTracker();
public SparkShuffleManager(SparkConf conf, boolean isDriver) {
+ SparkCommonUtils.validateAttemptConfig(conf);
this.conf = conf;
this.isDriver = isDriver;
this.celebornConf = SparkUtils.fromSparkConf(conf);
diff --git
a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index b74db710c..056b94c94 100644
---
a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++
b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -56,6 +56,7 @@ class CelebornShuffleReader[K, C](
handle.extension)
private val exceptionRef = new AtomicReference[IOException]
+ private val encodedAttemptId =
SparkCommonUtils.getEncodedAttemptNumber(context)
override def read(): Iterator[Product2[K, C]] = {
@@ -96,7 +97,7 @@ class CelebornShuffleReader[K, C](
val inputStream = shuffleClient.readPartition(
shuffleId,
partitionId,
- context.attemptNumber(),
+ encodedAttemptId,
startMapIndex,
endMapIndex,
metricsCallback)
diff --git
a/client-spark/spark-2/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
b/client-spark/spark-2/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
index cb6c3e0ab..0786904b8 100644
---
a/client-spark/spark-2/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
+++
b/client-spark/spark-2/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.shuffle.celeborn
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.internal.SQLConf
import org.junit
+import org.junit.Assert
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
@@ -67,4 +68,33 @@ class SparkShuffleManagerSuite extends Logging {
sc.stop()
}
+ @junit.Test
+ def testWrongSparkConfMaxAttemptLimit(): Unit = {
+ val conf = new SparkConf().setIfMissing("spark.master", "local")
+ .setIfMissing(
+ "spark.shuffle.manager",
+ "org.apache.spark.shuffle.celeborn.SparkShuffleManager")
+ .set(s"spark.${CelebornConf.MASTER_ENDPOINTS.key}", "localhost:9097")
+ .set(s"spark.${CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key}", "false")
+ .set("spark.shuffle.service.enabled", "false")
+ .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
+
+ // default conf, will success
+ new SparkShuffleManager(conf, true)
+
+ conf
+ .set("spark.stage.maxConsecutiveAttempts", "32768")
+ .set("spark.task.maxFailures", "10")
+ try {
+ new SparkShuffleManager(conf, true)
+ Assert.fail()
+ } catch {
+ case e: IllegalArgumentException =>
+ Assert.assertTrue(
+ e.getMessage.contains("The spark.stage.maxConsecutiveAttempts should
be less than 32768"))
+ case _: Throwable =>
+ Assert.fail()
+ }
+ }
+
}
diff --git
a/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
b/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
index 5df434f54..158cb9406 100644
---
a/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
+++
b/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.shuffle.celeborn
-import org.apache.spark.{ShuffleDependency, SparkConf}
+import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext}
import org.apache.spark.serializer.{KryoSerializer, KryoSerializerInstance}
import org.apache.spark.sql.execution.UnsafeRowSerializer
import
org.apache.spark.sql.execution.columnar.CelebornColumnarBatchSerializerInstance
@@ -45,6 +45,9 @@ class CelebornColumnarShuffleReaderSuite {
var shuffleClient: MockedStatic[ShuffleClient] = null
try {
+ val taskContext = Mockito.mock(classOf[TaskContext])
+ Mockito.when(taskContext.stageAttemptNumber).thenReturn(0)
+ Mockito.when(taskContext.attemptNumber).thenReturn(0)
shuffleClient = Mockito.mockStatic(classOf[ShuffleClient])
val shuffleReader = SparkUtils.createColumnarShuffleReader(
handle,
@@ -52,7 +55,7 @@ class CelebornColumnarShuffleReaderSuite {
10,
0,
10,
- null,
+ taskContext,
new CelebornConf(),
null,
new ExecutorShuffleIdTracker())
@@ -68,6 +71,9 @@ class CelebornColumnarShuffleReaderSuite {
def columnarShuffleReaderNewSerializerInstance(): Unit = {
var shuffleClient: MockedStatic[ShuffleClient] = null
try {
+ val taskContext = Mockito.mock(classOf[TaskContext])
+ Mockito.when(taskContext.stageAttemptNumber).thenReturn(0)
+ Mockito.when(taskContext.attemptNumber).thenReturn(0)
shuffleClient = Mockito.mockStatic(classOf[ShuffleClient])
val shuffleReader = SparkUtils.createColumnarShuffleReader(
new CelebornShuffleHandle[Int, String, String](
@@ -83,7 +89,7 @@ class CelebornColumnarShuffleReaderSuite {
10,
0,
10,
- null,
+ taskContext,
new CelebornConf(),
null,
new ExecutorShuffleIdTracker())
diff --git
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
index a3024bf00..c3808b6f1 100644
---
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
+++
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
@@ -68,6 +68,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
private final ShuffleWriteMetricsReporter writeMetrics;
private final int shuffleId;
private final int mapId;
+ private final int encodedAttemptId;
private final TaskContext taskContext;
private final ShuffleClient shuffleClient;
private final int numMappers;
@@ -112,6 +113,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
this.mapId = taskContext.partitionId();
this.dep = handle.dependency();
this.shuffleId = shuffleId;
+ this.encodedAttemptId =
SparkCommonUtils.getEncodedAttemptNumber(taskContext);
SerializerInstance serializer = dep.serializer().newInstance();
this.partitioner = dep.partitioner();
this.writeMetrics = metrics;
@@ -143,7 +145,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
new DataPusher(
shuffleId,
mapId,
- taskContext.attemptNumber(),
+ encodedAttemptId,
taskContext.taskAttemptId(),
numMappers,
numPartitions,
@@ -278,7 +280,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
shuffleClient.pushData(
shuffleId,
mapId,
- taskContext.attemptNumber(),
+ encodedAttemptId,
partitionId,
buffer,
0,
@@ -342,7 +344,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
shuffleClient.mergeData(
shuffleId,
mapId,
- taskContext.attemptNumber(),
+ encodedAttemptId,
partitionId,
buffer,
offset,
@@ -358,15 +360,15 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
long pushMergedDataTime = System.nanoTime();
dataPusher.waitOnTermination();
sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());
- shuffleClient.prepareForMergeData(shuffleId, mapId,
taskContext.attemptNumber());
+ shuffleClient.prepareForMergeData(shuffleId, mapId, encodedAttemptId);
closeWrite();
- shuffleClient.pushMergedData(shuffleId, mapId,
taskContext.attemptNumber());
+ shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);
writeMetrics.incWriteTime(System.nanoTime() - pushMergedDataTime);
updateMapStatus();
long waitStartTime = System.nanoTime();
- shuffleClient.mapperEnd(shuffleId, mapId, taskContext.attemptNumber(),
numMappers);
+ shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
BlockManagerId bmId = SparkEnv.get().blockManager().shuffleServerId();
@@ -403,7 +405,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
}
}
} finally {
- shuffleClient.cleanup(shuffleId, mapId, taskContext.attemptNumber());
+ shuffleClient.cleanup(shuffleId, mapId, encodedAttemptId);
}
}
diff --git
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
index cbb8ec725..7984f9ec8 100644
---
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
+++
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
@@ -62,6 +62,7 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
private final ShuffleWriteMetricsReporter writeMetrics;
private final int shuffleId;
private final int mapId;
+ private final int encodedAttemptId;
private final TaskContext taskContext;
private final ShuffleClient shuffleClient;
private final int numMappers;
@@ -102,6 +103,7 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
this.mapId = taskContext.partitionId();
this.dep = dep;
this.shuffleId = shuffleId;
+ this.encodedAttemptId =
SparkCommonUtils.getEncodedAttemptNumber(taskContext);
SerializerInstance serializer = dep.serializer().newInstance();
this.partitioner = dep.partitioner();
this.writeMetrics = metrics;
@@ -129,7 +131,7 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
taskContext,
shuffleId,
mapId,
- taskContext.attemptNumber(),
+ encodedAttemptId,
taskContext.taskAttemptId(),
numMappers,
numPartitions,
@@ -298,7 +300,7 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
shuffleClient.pushData(
shuffleId,
mapId,
- taskContext.attemptNumber(),
+ encodedAttemptId,
partitionId,
buffer,
0,
@@ -315,13 +317,13 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
pusher.pushData();
pusher.close();
- shuffleClient.pushMergedData(shuffleId, mapId,
taskContext.attemptNumber());
+ shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);
writeMetrics.incWriteTime(System.nanoTime() - pushStartTime);
updateMapStatus();
long waitStartTime = System.nanoTime();
- shuffleClient.mapperEnd(shuffleId, mapId, taskContext.attemptNumber(),
numMappers);
+ shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
}
@@ -355,7 +357,7 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
}
}
} finally {
- shuffleClient.cleanup(shuffleId, mapId, taskContext.attemptNumber());
+ shuffleClient.cleanup(shuffleId, mapId, encodedAttemptId);
}
}
diff --git
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index d97e774bf..35cc39846 100644
---
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -94,6 +94,7 @@ public class SparkShuffleManager implements ShuffleManager {
+ "use Celeborn as Remote Shuffle Service to avoid performance
degradation.",
SQLConf.LOCAL_SHUFFLE_READER_ENABLED().key());
}
+ SparkCommonUtils.validateAttemptConfig(conf);
this.conf = conf;
this.isDriver = isDriver;
this.celebornConf = SparkUtils.fromSparkConf(conf);
diff --git
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index 32c618bc2..9ba26116e 100644
---
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -58,6 +58,7 @@ class CelebornShuffleReader[K, C](
private val exceptionRef = new AtomicReference[IOException]
private val throwsFetchFailure = handle.throwsFetchFailure
+ private val encodedAttemptId =
SparkCommonUtils.getEncodedAttemptNumber(context)
override def read(): Iterator[Product2[K, C]] = {
@@ -117,7 +118,7 @@ class CelebornShuffleReader[K, C](
shuffleId,
handle.shuffleId,
partitionId,
- context.attemptNumber(),
+ encodedAttemptId,
startMapIndex,
endMapIndex,
if (throwsFetchFailure) exceptionMaker else null,
diff --git
a/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
b/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
index 019d3f48e..21aa1074c 100644
---
a/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
+++
b/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.shuffle.celeborn
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.internal.SQLConf
import org.junit
+import org.junit.Assert
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
@@ -67,4 +68,33 @@ class SparkShuffleManagerSuite extends Logging {
// scalastyle:on println
sc.stop()
}
+
+ @junit.Test
+ def testWrongSparkConfMaxAttemptLimit(): Unit = {
+ val conf = new SparkConf().setIfMissing("spark.master", "local")
+ .setIfMissing(
+ "spark.shuffle.manager",
+ "org.apache.spark.shuffle.celeborn.SparkShuffleManager")
+ .set(s"spark.${CelebornConf.MASTER_ENDPOINTS.key}", "localhost:9097")
+ .set(s"spark.${CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key}", "false")
+ .set("spark.shuffle.service.enabled", "false")
+ .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
+
+ // default conf, will success
+ new SparkShuffleManager(conf, true)
+
+ conf
+ .set("spark.stage.maxConsecutiveAttempts", "32768")
+ .set("spark.task.maxFailures", "10")
+ try {
+ new SparkShuffleManager(conf, true)
+ Assert.fail()
+ } catch {
+ case e: IllegalArgumentException =>
+ Assert.assertTrue(
+ e.getMessage.contains("The spark.stage.maxConsecutiveAttempts should
be less than 32768"))
+ case _: Throwable =>
+ Assert.fail()
+ }
+ }
}