This is an automated email from the ASF dual-hosted git repository.
rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 3853075fd [CELEBORN-1496] Differentiate map results with only
different stageAttemptId
3853075fd is described below
commit 3853075fdd30898c3010136f74cdb1fa939f0f84
Author: jiang13021 <[email protected]>
AuthorDate: Fri Aug 30 09:40:54 2024 +0800
[CELEBORN-1496] Differentiate map results with only different stageAttemptId
### What changes were proposed in this pull request?
Encode the attempt number sent to Celeborn as
(stageAttemptNumber << 16) | taskAttemptNumber, so that map results that differ
only in their stage attempt ID can be distinguished.
### Why are the changes needed?
If map tasks that differ only in their stage attempt ID cannot be distinguished,
shuffle read may mix batches written by both map attempts, causing a data
correctness issue.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added unit test:
org.apache.spark.shuffle.celeborn.SparkShuffleManagerSuite#testWrongSparkConfMaxAttemptLimit
Closes #2609 from jiang13021/spark_stage_attempt_id.
Lead-authored-by: jiang13021 <[email protected]>
Co-authored-by: Fu Chen <[email protected]>
Co-authored-by: Shuang <[email protected]>
Signed-off-by: Shuang <[email protected]>
---
.../spark/shuffle/celeborn/SparkCommonUtils.java | 53 ++++++++++++++++++++++
.../shuffle/celeborn/HashBasedShuffleWriter.java | 16 ++++---
.../shuffle/celeborn/SortBasedShuffleWriter.java | 12 +++--
.../shuffle/celeborn/SparkShuffleManager.java | 1 +
.../shuffle/celeborn/CelebornShuffleReader.scala | 3 +-
.../celeborn/CelebornShuffleManagerSuite.scala | 30 ++++++++++++
.../CelebornColumnarShuffleReaderSuite.scala | 12 +++--
.../shuffle/celeborn/HashBasedShuffleWriter.java | 16 ++++---
.../shuffle/celeborn/SortBasedShuffleWriter.java | 12 +++--
.../shuffle/celeborn/SparkShuffleManager.java | 1 +
.../shuffle/celeborn/CelebornShuffleReader.scala | 3 +-
.../celeborn/CelebornShuffleManagerSuite.scala | 30 ++++++++++++
12 files changed, 160 insertions(+), 29 deletions(-)
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java
new file mode 100644
index 000000000..f00fe063e
--- /dev/null
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.scheduler.DAGScheduler;
+
+/** Attempt-number helpers shared by the Spark 2 and Spark 3 Celeborn shuffle clients. */
+public class SparkCommonUtils {
+  /**
+   * Rejects limits that {@link #getEncodedAttemptNumber} cannot pack: the stage attempt uses the
+   * high 16 bits (below 32768 keeps ids non-negative), the task attempt the low 16 bits.
+   *
+   * @throws IllegalArgumentException if either attempt limit is too large to encode
+   */
+  public static void validateAttemptConfig(SparkConf conf) throws IllegalArgumentException {
+    int maxStageAttempts =
+        conf.getInt(
+            "spark.stage.maxConsecutiveAttempts",
+            DAGScheduler.DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS());
+    // Spark 2 names this MAX_TASK_FAILURES and Spark 3 TASK_MAX_FAILURES; both default to 4.
+    int maxTaskAttempts = conf.getInt("spark.task.maxFailures", 4);
+    if (maxStageAttempts >= (1 << 15) || maxTaskAttempts >= (1 << 16)) {
+      throw new IllegalArgumentException(
+          "The spark.stage.maxConsecutiveAttempts should be less than 32768 (currently "
+              + maxStageAttempts
+              + ") and spark.task.maxFailures should be less than 65536 (currently "
+              + maxTaskAttempts
+              + ").");
+    }
+  }
+
+  /** Returns {@code (stageAttemptNumber << 16) | attemptNumber} of the given task. */
+  public static int getEncodedAttemptNumber(TaskContext context) {
+    return (context.stageAttemptNumber() << 16) | context.attemptNumber();
+  }
+}
diff --git
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
index fd83ec2e1..6db620b41 100644
---
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
+++
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
@@ -68,6 +68,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
private final ShuffleWriteMetrics writeMetrics;
private final int shuffleId;
private final int mapId;
+ private final int encodedAttemptId;
private final TaskContext taskContext;
private final ShuffleClient shuffleClient;
private final int numMappers;
@@ -112,6 +113,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
this.mapId = mapId;
this.dep = handle.dependency();
this.shuffleId = shuffleId;
+ this.encodedAttemptId =
SparkCommonUtils.getEncodedAttemptNumber(taskContext);
SerializerInstance serializer = dep.serializer().newInstance();
this.partitioner = dep.partitioner();
this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics();
@@ -146,7 +148,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
new DataPusher(
shuffleId,
mapId,
- taskContext.attemptNumber(),
+ encodedAttemptId,
taskContext.taskAttemptId(),
numMappers,
numPartitions,
@@ -279,7 +281,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
shuffleClient.pushData(
shuffleId,
mapId,
- taskContext.attemptNumber(),
+ encodedAttemptId,
partitionId,
buffer,
0,
@@ -333,7 +335,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
// here we wait for all the in-flight batches to return which sent by
dataPusher thread
dataPusher.waitOnTermination();
sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());
- shuffleClient.prepareForMergeData(shuffleId, mapId,
taskContext.attemptNumber());
+ shuffleClient.prepareForMergeData(shuffleId, mapId, encodedAttemptId);
// merge and push residual data to reduce network traffic
// NB: since dataPusher thread have no in-flight data at this point,
@@ -345,7 +347,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
shuffleClient.mergeData(
shuffleId,
mapId,
- taskContext.attemptNumber(),
+ encodedAttemptId,
i,
sendBuffers[i],
0,
@@ -358,7 +360,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
writeMetrics.incBytesWritten(bytesWritten);
}
}
- shuffleClient.pushMergedData(shuffleId, mapId,
taskContext.attemptNumber());
+ shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);
updateMapStatus();
@@ -367,7 +369,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
sendOffsets = null;
long waitStartTime = System.nanoTime();
- shuffleClient.mapperEnd(shuffleId, mapId, taskContext.attemptNumber(),
numMappers);
+ shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
BlockManagerId bmId = SparkEnv.get().blockManager().shuffleServerId();
@@ -404,7 +406,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
}
}
} finally {
- shuffleClient.cleanup(shuffleId, mapId, taskContext.attemptNumber());
+ shuffleClient.cleanup(shuffleId, mapId, encodedAttemptId);
}
}
}
diff --git
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
index 7b8baaf06..2d65b6859 100644
---
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
+++
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
@@ -62,6 +62,7 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
private final ShuffleWriteMetrics writeMetrics;
private final int shuffleId;
private final int mapId;
+ private final int encodedAttemptId;
private final TaskContext taskContext;
private final ShuffleClient shuffleClient;
private final int numMappers;
@@ -102,6 +103,7 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
this.mapId = taskContext.partitionId();
this.dep = dep;
this.shuffleId = shuffleId;
+ this.encodedAttemptId =
SparkCommonUtils.getEncodedAttemptNumber(taskContext);
SerializerInstance serializer = dep.serializer().newInstance();
this.partitioner = dep.partitioner();
this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics();
@@ -130,7 +132,7 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
taskContext,
shuffleId,
mapId,
- taskContext.attemptNumber(),
+ encodedAttemptId,
taskContext.taskAttemptId(),
numMappers,
numPartitions,
@@ -285,7 +287,7 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
shuffleClient.pushData(
shuffleId,
mapId,
- taskContext.attemptNumber(),
+ encodedAttemptId,
partitionId,
buffer,
0,
@@ -309,12 +311,12 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
pusher.close(true);
writeMetrics.incWriteTime(System.nanoTime() - pushStartTime);
- shuffleClient.pushMergedData(shuffleId, mapId,
taskContext.attemptNumber());
+ shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);
updateMapStatus();
long waitStartTime = System.nanoTime();
- shuffleClient.mapperEnd(shuffleId, mapId, taskContext.attemptNumber(),
numMappers);
+ shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
}
@@ -350,7 +352,7 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
} catch (IOException e) {
return Option.apply(null);
} finally {
- shuffleClient.cleanup(shuffleId, mapId, taskContext.attemptNumber());
+ shuffleClient.cleanup(shuffleId, mapId, encodedAttemptId);
}
}
diff --git
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index 83280c748..4f6e835e7 100644
---
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -66,6 +66,7 @@ public class SparkShuffleManager implements ShuffleManager {
private ExecutorShuffleIdTracker shuffleIdTracker = new
ExecutorShuffleIdTracker();
public SparkShuffleManager(SparkConf conf, boolean isDriver) {
+ SparkCommonUtils.validateAttemptConfig(conf);
this.conf = conf;
this.isDriver = isDriver;
this.celebornConf = SparkUtils.fromSparkConf(conf);
diff --git
a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index 9cabfebc9..4a3092275 100644
---
a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++
b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -56,6 +56,7 @@ class CelebornShuffleReader[K, C](
handle.extension)
private val exceptionRef = new AtomicReference[IOException]
+ private val encodedAttemptId =
SparkCommonUtils.getEncodedAttemptNumber(context)
override def read(): Iterator[Product2[K, C]] = {
@@ -96,7 +97,7 @@ class CelebornShuffleReader[K, C](
val inputStream = shuffleClient.readPartition(
shuffleId,
partitionId,
- context.attemptNumber(),
+ encodedAttemptId,
startMapIndex,
endMapIndex,
metricsCallback)
diff --git
a/client-spark/spark-2/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
b/client-spark/spark-2/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
index cb6c3e0ab..0786904b8 100644
---
a/client-spark/spark-2/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
+++
b/client-spark/spark-2/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.shuffle.celeborn
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.internal.SQLConf
import org.junit
+import org.junit.Assert
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
@@ -67,4 +68,33 @@ class SparkShuffleManagerSuite extends Logging {
sc.stop()
}
+  @junit.Test
+  def testWrongSparkConfMaxAttemptLimit(): Unit = {
+    val conf = new SparkConf().setIfMissing("spark.master", "local")
+      .setIfMissing(
+        "spark.shuffle.manager",
+        "org.apache.spark.shuffle.celeborn.SparkShuffleManager")
+      .set(s"spark.${CelebornConf.MASTER_ENDPOINTS.key}", "localhost:9097")
+      .set(s"spark.${CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key}", "false")
+      .set("spark.shuffle.service.enabled", "false")
+      .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
+
+    // Default attempt limits fit in the encoded attempt id, so construction succeeds.
+    new SparkShuffleManager(conf, true)
+
+    conf
+      .set("spark.stage.maxConsecutiveAttempts", "32768")
+      .set("spark.task.maxFailures", "10")
+    val error = try {
+      new SparkShuffleManager(conf, true)
+      null
+    } catch {
+      case e: IllegalArgumentException => e
+    }
+    Assert.assertNotNull("expected IllegalArgumentException", error)
+    Assert.assertTrue(
+      error.getMessage,
+      error.getMessage.contains("The spark.stage.maxConsecutiveAttempts should be less than 32768"))
+  }
+
}
diff --git
a/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
b/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
index 5df434f54..158cb9406 100644
---
a/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
+++
b/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.shuffle.celeborn
-import org.apache.spark.{ShuffleDependency, SparkConf}
+import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext}
import org.apache.spark.serializer.{KryoSerializer, KryoSerializerInstance}
import org.apache.spark.sql.execution.UnsafeRowSerializer
import
org.apache.spark.sql.execution.columnar.CelebornColumnarBatchSerializerInstance
@@ -45,6 +45,9 @@ class CelebornColumnarShuffleReaderSuite {
var shuffleClient: MockedStatic[ShuffleClient] = null
try {
+ val taskContext = Mockito.mock(classOf[TaskContext])
+ Mockito.when(taskContext.stageAttemptNumber).thenReturn(0)
+ Mockito.when(taskContext.attemptNumber).thenReturn(0)
shuffleClient = Mockito.mockStatic(classOf[ShuffleClient])
val shuffleReader = SparkUtils.createColumnarShuffleReader(
handle,
@@ -52,7 +55,7 @@ class CelebornColumnarShuffleReaderSuite {
10,
0,
10,
- null,
+ taskContext,
new CelebornConf(),
null,
new ExecutorShuffleIdTracker())
@@ -68,6 +71,9 @@ class CelebornColumnarShuffleReaderSuite {
def columnarShuffleReaderNewSerializerInstance(): Unit = {
var shuffleClient: MockedStatic[ShuffleClient] = null
try {
+ val taskContext = Mockito.mock(classOf[TaskContext])
+ Mockito.when(taskContext.stageAttemptNumber).thenReturn(0)
+ Mockito.when(taskContext.attemptNumber).thenReturn(0)
shuffleClient = Mockito.mockStatic(classOf[ShuffleClient])
val shuffleReader = SparkUtils.createColumnarShuffleReader(
new CelebornShuffleHandle[Int, String, String](
@@ -83,7 +89,7 @@ class CelebornColumnarShuffleReaderSuite {
10,
0,
10,
- null,
+ taskContext,
new CelebornConf(),
null,
new ExecutorShuffleIdTracker())
diff --git
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
index bd1eeb16c..7ed869844 100644
---
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
+++
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
@@ -68,6 +68,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
private final ShuffleWriteMetricsReporter writeMetrics;
private final int shuffleId;
private final int mapId;
+ private final int encodedAttemptId;
private final TaskContext taskContext;
private final ShuffleClient shuffleClient;
private final int numMappers;
@@ -112,6 +113,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
this.mapId = taskContext.partitionId();
this.dep = handle.dependency();
this.shuffleId = shuffleId;
+ this.encodedAttemptId =
SparkCommonUtils.getEncodedAttemptNumber(taskContext);
SerializerInstance serializer = dep.serializer().newInstance();
this.partitioner = dep.partitioner();
this.writeMetrics = metrics;
@@ -142,7 +144,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
new DataPusher(
shuffleId,
mapId,
- taskContext.attemptNumber(),
+ encodedAttemptId,
taskContext.taskAttemptId(),
numMappers,
numPartitions,
@@ -279,7 +281,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
shuffleClient.pushData(
shuffleId,
mapId,
- taskContext.attemptNumber(),
+ encodedAttemptId,
partitionId,
buffer,
0,
@@ -343,7 +345,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
shuffleClient.mergeData(
shuffleId,
mapId,
- taskContext.attemptNumber(),
+ encodedAttemptId,
partitionId,
buffer,
offset,
@@ -368,14 +370,14 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
long pushMergedDataTime = System.nanoTime();
dataPusher.waitOnTermination();
sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());
- shuffleClient.prepareForMergeData(shuffleId, mapId,
taskContext.attemptNumber());
+ shuffleClient.prepareForMergeData(shuffleId, mapId, encodedAttemptId);
closeWrite();
- shuffleClient.pushMergedData(shuffleId, mapId,
taskContext.attemptNumber());
+ shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);
writeMetrics.incWriteTime(System.nanoTime() - pushMergedDataTime);
updateRecordsWrittenMetrics();
long waitStartTime = System.nanoTime();
- shuffleClient.mapperEnd(shuffleId, mapId, taskContext.attemptNumber(),
numMappers);
+ shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
BlockManagerId bmId = SparkEnv.get().blockManager().shuffleServerId();
@@ -408,7 +410,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
}
}
} finally {
- shuffleClient.cleanup(shuffleId, mapId, taskContext.attemptNumber());
+ shuffleClient.cleanup(shuffleId, mapId, encodedAttemptId);
}
}
diff --git
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
index 1c8900bb2..5717910ee 100644
---
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
+++
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
@@ -62,6 +62,7 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
private final ShuffleWriteMetricsReporter writeMetrics;
private final int shuffleId;
private final int mapId;
+ private final int encodedAttemptId;
private final TaskContext taskContext;
private final ShuffleClient shuffleClient;
private final int numMappers;
@@ -116,6 +117,7 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
this.mapId = taskContext.partitionId();
this.dep = dep;
this.shuffleId = shuffleId;
+ this.encodedAttemptId =
SparkCommonUtils.getEncodedAttemptNumber(taskContext);
SerializerInstance serializer = dep.serializer().newInstance();
this.partitioner = dep.partitioner();
this.writeMetrics = metrics;
@@ -143,7 +145,7 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
taskContext,
shuffleId,
mapId,
- taskContext.attemptNumber(),
+ encodedAttemptId,
taskContext.taskAttemptId(),
numMappers,
numPartitions,
@@ -348,7 +350,7 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
shuffleClient.pushData(
shuffleId,
mapId,
- taskContext.attemptNumber(),
+ encodedAttemptId,
partitionId,
buffer,
0,
@@ -371,12 +373,12 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
pusher.pushData(false);
pusher.close(true);
- shuffleClient.pushMergedData(shuffleId, mapId,
taskContext.attemptNumber());
+ shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);
writeMetrics.incWriteTime(System.nanoTime() - pushStartTime);
writeMetrics.incRecordsWritten(tmpRecordsWritten);
long waitStartTime = System.nanoTime();
- shuffleClient.mapperEnd(shuffleId, mapId, taskContext.attemptNumber(),
numMappers);
+ shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
}
@@ -403,7 +405,7 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
}
}
} finally {
- shuffleClient.cleanup(shuffleId, mapId, taskContext.attemptNumber());
+ shuffleClient.cleanup(shuffleId, mapId, encodedAttemptId);
}
}
diff --git
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index 4c84e9d53..da785886c 100644
---
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -107,6 +107,7 @@ public class SparkShuffleManager implements ShuffleManager {
key,
defaultValue);
}
+ SparkCommonUtils.validateAttemptConfig(conf);
this.conf = conf;
this.isDriver = isDriver;
this.celebornConf = SparkUtils.fromSparkConf(conf);
diff --git
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index 402604439..7cfa0a324 100644
---
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -66,6 +66,7 @@ class CelebornShuffleReader[K, C](
private val exceptionRef = new AtomicReference[IOException]
private val throwsFetchFailure = handle.throwsFetchFailure
+ private val encodedAttemptId =
SparkCommonUtils.getEncodedAttemptNumber(context)
override def read(): Iterator[Product2[K, C]] = {
@@ -193,7 +194,7 @@ class CelebornShuffleReader[K, C](
shuffleId,
handle.shuffleId,
partitionId,
- context.attemptNumber(),
+ encodedAttemptId,
startMapIndex,
endMapIndex,
if (throwsFetchFailure)
ExceptionMakerHelper.SHUFFLE_FETCH_FAILURE_EXCEPTION_MAKER
diff --git
a/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
b/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
index 4e998ffc7..7b790cbbf 100644
---
a/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
+++
b/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.shuffle.celeborn
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.internal.SQLConf
import org.junit
+import org.junit.Assert
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
@@ -89,4 +90,33 @@ class SparkShuffleManagerSuite extends Logging {
// scalastyle:on println
sc.stop()
}
+
+  @junit.Test
+  def testWrongSparkConfMaxAttemptLimit(): Unit = {
+    val conf = new SparkConf().setIfMissing("spark.master", "local")
+      .setIfMissing(
+        "spark.shuffle.manager",
+        "org.apache.spark.shuffle.celeborn.SparkShuffleManager")
+      .set(s"spark.${CelebornConf.MASTER_ENDPOINTS.key}", "localhost:9097")
+      .set(s"spark.${CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key}", "false")
+      .set("spark.shuffle.service.enabled", "false")
+      .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
+
+    // Default attempt limits fit in the encoded attempt id, so construction succeeds.
+    new SparkShuffleManager(conf, true)
+
+    conf
+      .set("spark.stage.maxConsecutiveAttempts", "32768")
+      .set("spark.task.maxFailures", "10")
+    val error = try {
+      new SparkShuffleManager(conf, true)
+      null
+    } catch {
+      case e: IllegalArgumentException => e
+    }
+    Assert.assertNotNull("expected IllegalArgumentException", error)
+    Assert.assertTrue(
+      error.getMessage,
+      error.getMessage.contains("The spark.stage.maxConsecutiveAttempts should be less than 32768"))
+  }
}