This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch branch-0.5
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/branch-0.5 by this push:
     new b9382a045 [CELEBORN-1496] Differentiate map results with only 
different stageAttemptId
b9382a045 is described below

commit b9382a045b0e1c0d04d470d53237b8dee680e2be
Author: jiang13021 <[email protected]>
AuthorDate: Fri Aug 30 09:40:54 2024 +0800

    [CELEBORN-1496] Differentiate map results with only different stageAttemptId
    
    ### What changes were proposed in this pull request?
    Let attemptNumber = (stageAttemptId << 16) | taskAttemptNumber, so that 
map results that differ only in stageAttemptId can be told apart.
    
    ### Why are the changes needed?
    If map tasks that differ only in stageAttemptId cannot be told apart, 
shuffle read may mix the shuffle write batches of two map tasks, causing a 
data correctness issue.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added a unit test: 
org.apache.spark.shuffle.celeborn.SparkShuffleManagerSuite#testWrongSparkConfMaxAttemptLimit
    
    Closes #2609 from jiang13021/spark_stage_attempt_id.
    
    Lead-authored-by: jiang13021 <[email protected]>
    Co-authored-by: Fu Chen <[email protected]>
    Co-authored-by: Shuang <[email protected]>
    Signed-off-by: Shuang <[email protected]>
    (cherry picked from commit 3853075fdd30898c3010136f74cdb1fa939f0f84)
    Signed-off-by: Shuang <[email protected]>
---
 .../spark/shuffle/celeborn/SparkCommonUtils.java   | 53 ++++++++++++++++++++++
 .../shuffle/celeborn/HashBasedShuffleWriter.java   | 16 ++++---
 .../shuffle/celeborn/SortBasedShuffleWriter.java   | 12 +++--
 .../shuffle/celeborn/SparkShuffleManager.java      |  1 +
 .../shuffle/celeborn/CelebornShuffleReader.scala   |  3 +-
 .../celeborn/CelebornShuffleManagerSuite.scala     | 30 ++++++++++++
 .../CelebornColumnarShuffleReaderSuite.scala       | 12 +++--
 .../shuffle/celeborn/HashBasedShuffleWriter.java   | 16 ++++---
 .../shuffle/celeborn/SortBasedShuffleWriter.java   | 12 +++--
 .../shuffle/celeborn/SparkShuffleManager.java      |  1 +
 .../shuffle/celeborn/CelebornShuffleReader.scala   |  3 +-
 .../celeborn/CelebornShuffleManagerSuite.scala     | 30 ++++++++++++
 12 files changed, 160 insertions(+), 29 deletions(-)

diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java
new file mode 100644
index 000000000..f00fe063e
--- /dev/null
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.scheduler.DAGScheduler;
+
+/**
+ * Helpers used by the Spark 2 and Spark 3 clients for combining a task's stage attempt number
+ * and task attempt number into a single attempt id.
+ */
+public class SparkCommonUtils {
+  /**
+   * Checks that the configured attempt limits fit the layout produced by {@link
+   * #getEncodedAttemptNumber(TaskContext)}.
+   *
+   * <p>The encoded attempt id is a non-negative int: the high 16 bits hold the stage attempt
+   * number and the low 16 bits hold the task attempt number. Therefore {@code
+   * spark.stage.maxConsecutiveAttempts} must be less than 32768 ({@code 1 << 15}) and {@code
+   * spark.task.maxFailures} must be less than 65536 ({@code 1 << 16}).
+   *
+   * @param conf Spark configuration the two limits are read from
+   * @throws IllegalArgumentException if either limit is too large for the encoding
+   */
+  public static void validateAttemptConfig(SparkConf conf) throws IllegalArgumentException {
+    int maxStageAttempts =
+        conf.getInt(
+            "spark.stage.maxConsecutiveAttempts",
+            DAGScheduler.DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS());
+    // In Spark 2, the parameter is referred to as MAX_TASK_FAILURES, while in Spark 3, it has
+    // been changed to TASK_MAX_FAILURES. The default value for both is consistently set to 4.
+    int maxTaskAttempts = conf.getInt("spark.task.maxFailures", 4);
+    if (maxStageAttempts >= (1 << 15) || maxTaskAttempts >= (1 << 16)) {
+      // Either limit alone is enough to reject the configuration; the message reports both
+      // current values so the offending one can be identified.
+      throw new IllegalArgumentException(
+          "The spark.stage.maxConsecutiveAttempts should be less than 32768 (currently "
+              + maxStageAttempts
+              + ") and spark.task.maxFailures should be less than 65536 (currently "
+              + maxTaskAttempts
+              + ").");
+    }
+  }
+
+  /**
+   * Packs the stage attempt number and the task attempt number of {@code context} into one int.
+   *
+   * @param context task context supplying both attempt numbers
+   * @return {@code (stageAttemptNumber << 16) | attemptNumber}
+   */
+  public static int getEncodedAttemptNumber(TaskContext context) {
+    return (context.stageAttemptNumber() << 16) | context.attemptNumber();
+  }
+}
diff --git 
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
index fd83ec2e1..6db620b41 100644
--- 
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
+++ 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
@@ -68,6 +68,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private final ShuffleWriteMetrics writeMetrics;
   private final int shuffleId;
   private final int mapId;
+  private final int encodedAttemptId;
   private final TaskContext taskContext;
   private final ShuffleClient shuffleClient;
   private final int numMappers;
@@ -112,6 +113,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     this.mapId = mapId;
     this.dep = handle.dependency();
     this.shuffleId = shuffleId;
+    this.encodedAttemptId = 
SparkCommonUtils.getEncodedAttemptNumber(taskContext);
     SerializerInstance serializer = dep.serializer().newInstance();
     this.partitioner = dep.partitioner();
     this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics();
@@ -146,7 +148,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
           new DataPusher(
               shuffleId,
               mapId,
-              taskContext.attemptNumber(),
+              encodedAttemptId,
               taskContext.taskAttemptId(),
               numMappers,
               numPartitions,
@@ -279,7 +281,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         shuffleClient.pushData(
             shuffleId,
             mapId,
-            taskContext.attemptNumber(),
+            encodedAttemptId,
             partitionId,
             buffer,
             0,
@@ -333,7 +335,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     // here we wait for all the in-flight batches to return which sent by 
dataPusher thread
     dataPusher.waitOnTermination();
     sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());
-    shuffleClient.prepareForMergeData(shuffleId, mapId, 
taskContext.attemptNumber());
+    shuffleClient.prepareForMergeData(shuffleId, mapId, encodedAttemptId);
 
     // merge and push residual data to reduce network traffic
     // NB: since dataPusher thread have no in-flight data at this point,
@@ -345,7 +347,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
             shuffleClient.mergeData(
                 shuffleId,
                 mapId,
-                taskContext.attemptNumber(),
+                encodedAttemptId,
                 i,
                 sendBuffers[i],
                 0,
@@ -358,7 +360,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         writeMetrics.incBytesWritten(bytesWritten);
       }
     }
-    shuffleClient.pushMergedData(shuffleId, mapId, 
taskContext.attemptNumber());
+    shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);
 
     updateMapStatus();
 
@@ -367,7 +369,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     sendOffsets = null;
 
     long waitStartTime = System.nanoTime();
-    shuffleClient.mapperEnd(shuffleId, mapId, taskContext.attemptNumber(), 
numMappers);
+    shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
     writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
 
     BlockManagerId bmId = SparkEnv.get().blockManager().shuffleServerId();
@@ -404,7 +406,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         }
       }
     } finally {
-      shuffleClient.cleanup(shuffleId, mapId, taskContext.attemptNumber());
+      shuffleClient.cleanup(shuffleId, mapId, encodedAttemptId);
     }
   }
 }
diff --git 
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
index 7b8baaf06..2d65b6859 100644
--- 
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
+++ 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
@@ -62,6 +62,7 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private final ShuffleWriteMetrics writeMetrics;
   private final int shuffleId;
   private final int mapId;
+  private final int encodedAttemptId;
   private final TaskContext taskContext;
   private final ShuffleClient shuffleClient;
   private final int numMappers;
@@ -102,6 +103,7 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     this.mapId = taskContext.partitionId();
     this.dep = dep;
     this.shuffleId = shuffleId;
+    this.encodedAttemptId = 
SparkCommonUtils.getEncodedAttemptNumber(taskContext);
     SerializerInstance serializer = dep.serializer().newInstance();
     this.partitioner = dep.partitioner();
     this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics();
@@ -130,7 +132,7 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
             taskContext,
             shuffleId,
             mapId,
-            taskContext.attemptNumber(),
+            encodedAttemptId,
             taskContext.taskAttemptId(),
             numMappers,
             numPartitions,
@@ -285,7 +287,7 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         shuffleClient.pushData(
             shuffleId,
             mapId,
-            taskContext.attemptNumber(),
+            encodedAttemptId,
             partitionId,
             buffer,
             0,
@@ -309,12 +311,12 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     pusher.close(true);
     writeMetrics.incWriteTime(System.nanoTime() - pushStartTime);
 
-    shuffleClient.pushMergedData(shuffleId, mapId, 
taskContext.attemptNumber());
+    shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);
 
     updateMapStatus();
 
     long waitStartTime = System.nanoTime();
-    shuffleClient.mapperEnd(shuffleId, mapId, taskContext.attemptNumber(), 
numMappers);
+    shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
     writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
   }
 
@@ -350,7 +352,7 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     } catch (IOException e) {
       return Option.apply(null);
     } finally {
-      shuffleClient.cleanup(shuffleId, mapId, taskContext.attemptNumber());
+      shuffleClient.cleanup(shuffleId, mapId, encodedAttemptId);
     }
   }
 
diff --git 
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index bb0bd447b..aa5549184 100644
--- 
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++ 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -66,6 +66,7 @@ public class SparkShuffleManager implements ShuffleManager {
   private ExecutorShuffleIdTracker shuffleIdTracker = new 
ExecutorShuffleIdTracker();
 
   public SparkShuffleManager(SparkConf conf, boolean isDriver) {
+    SparkCommonUtils.validateAttemptConfig(conf);
     this.conf = conf;
     this.isDriver = isDriver;
     this.celebornConf = SparkUtils.fromSparkConf(conf);
diff --git 
a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
 
b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index 512d8f41c..ac74f92d6 100644
--- 
a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++ 
b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -56,6 +56,7 @@ class CelebornShuffleReader[K, C](
     handle.extension)
 
   private val exceptionRef = new AtomicReference[IOException]
+  private val encodedAttemptId = 
SparkCommonUtils.getEncodedAttemptNumber(context)
 
   override def read(): Iterator[Product2[K, C]] = {
 
@@ -96,7 +97,7 @@ class CelebornShuffleReader[K, C](
               val inputStream = shuffleClient.readPartition(
                 shuffleId,
                 partitionId,
-                context.attemptNumber(),
+                encodedAttemptId,
                 startMapIndex,
                 endMapIndex,
                 metricsCallback)
diff --git 
a/client-spark/spark-2/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
 
b/client-spark/spark-2/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
index cb6c3e0ab..0786904b8 100644
--- 
a/client-spark/spark-2/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
+++ 
b/client-spark/spark-2/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.shuffle.celeborn
 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.sql.internal.SQLConf
 import org.junit
+import org.junit.Assert
 import org.junit.runner.RunWith
 import org.junit.runners.JUnit4
 
@@ -67,4 +68,33 @@ class SparkShuffleManagerSuite extends Logging {
     sc.stop()
   }
 
+  @junit.Test // The manager constructor must reject attempt limits that do not fit the encoded attempt id.
+  def testWrongSparkConfMaxAttemptLimit(): Unit = {
+    val conf = new SparkConf().setIfMissing("spark.master", "local")
+      .setIfMissing(
+        "spark.shuffle.manager",
+        "org.apache.spark.shuffle.celeborn.SparkShuffleManager")
+      .set(s"spark.${CelebornConf.MASTER_ENDPOINTS.key}", "localhost:9097")
+      .set(s"spark.${CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key}", "false")
+      .set("spark.shuffle.service.enabled", "false")
+      .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
+
+    // With the attempt limits left unset, constructing the manager is expected to succeed.
+    new SparkShuffleManager(conf, true)
+
+    conf
+      .set("spark.stage.maxConsecutiveAttempts", "32768") // 32768 == 1 << 15, the smallest rejected value
+      .set("spark.task.maxFailures", "10")
+    try {
+      new SparkShuffleManager(conf, true)
+      Assert.fail() // reached only if the constructor did not throw
+    } catch {
+      case e: IllegalArgumentException =>
+        Assert.assertTrue(
+          e.getMessage.contains("The spark.stage.maxConsecutiveAttempts should 
be less than 32768"))
+      case _: Throwable => // also catches the AssertionError raised by Assert.fail() above
+        Assert.fail()
+    }
+  }
+
 }
diff --git 
a/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
 
b/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
index 5df434f54..158cb9406 100644
--- 
a/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
+++ 
b/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.shuffle.celeborn
 
-import org.apache.spark.{ShuffleDependency, SparkConf}
+import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext}
 import org.apache.spark.serializer.{KryoSerializer, KryoSerializerInstance}
 import org.apache.spark.sql.execution.UnsafeRowSerializer
 import 
org.apache.spark.sql.execution.columnar.CelebornColumnarBatchSerializerInstance
@@ -45,6 +45,9 @@ class CelebornColumnarShuffleReaderSuite {
 
     var shuffleClient: MockedStatic[ShuffleClient] = null
     try {
+      val taskContext = Mockito.mock(classOf[TaskContext])
+      Mockito.when(taskContext.stageAttemptNumber).thenReturn(0)
+      Mockito.when(taskContext.attemptNumber).thenReturn(0)
       shuffleClient = Mockito.mockStatic(classOf[ShuffleClient])
       val shuffleReader = SparkUtils.createColumnarShuffleReader(
         handle,
@@ -52,7 +55,7 @@ class CelebornColumnarShuffleReaderSuite {
         10,
         0,
         10,
-        null,
+        taskContext,
         new CelebornConf(),
         null,
         new ExecutorShuffleIdTracker())
@@ -68,6 +71,9 @@ class CelebornColumnarShuffleReaderSuite {
   def columnarShuffleReaderNewSerializerInstance(): Unit = {
     var shuffleClient: MockedStatic[ShuffleClient] = null
     try {
+      val taskContext = Mockito.mock(classOf[TaskContext])
+      Mockito.when(taskContext.stageAttemptNumber).thenReturn(0)
+      Mockito.when(taskContext.attemptNumber).thenReturn(0)
       shuffleClient = Mockito.mockStatic(classOf[ShuffleClient])
       val shuffleReader = SparkUtils.createColumnarShuffleReader(
         new CelebornShuffleHandle[Int, String, String](
@@ -83,7 +89,7 @@ class CelebornColumnarShuffleReaderSuite {
         10,
         0,
         10,
-        null,
+        taskContext,
         new CelebornConf(),
         null,
         new ExecutorShuffleIdTracker())
diff --git 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
index bd1eeb16c..7ed869844 100644
--- 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
+++ 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
@@ -68,6 +68,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private final ShuffleWriteMetricsReporter writeMetrics;
   private final int shuffleId;
   private final int mapId;
+  private final int encodedAttemptId;
   private final TaskContext taskContext;
   private final ShuffleClient shuffleClient;
   private final int numMappers;
@@ -112,6 +113,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     this.mapId = taskContext.partitionId();
     this.dep = handle.dependency();
     this.shuffleId = shuffleId;
+    this.encodedAttemptId = 
SparkCommonUtils.getEncodedAttemptNumber(taskContext);
     SerializerInstance serializer = dep.serializer().newInstance();
     this.partitioner = dep.partitioner();
     this.writeMetrics = metrics;
@@ -142,7 +144,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
           new DataPusher(
               shuffleId,
               mapId,
-              taskContext.attemptNumber(),
+              encodedAttemptId,
               taskContext.taskAttemptId(),
               numMappers,
               numPartitions,
@@ -279,7 +281,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         shuffleClient.pushData(
             shuffleId,
             mapId,
-            taskContext.attemptNumber(),
+            encodedAttemptId,
             partitionId,
             buffer,
             0,
@@ -343,7 +345,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         shuffleClient.mergeData(
             shuffleId,
             mapId,
-            taskContext.attemptNumber(),
+            encodedAttemptId,
             partitionId,
             buffer,
             offset,
@@ -368,14 +370,14 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     long pushMergedDataTime = System.nanoTime();
     dataPusher.waitOnTermination();
     sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());
-    shuffleClient.prepareForMergeData(shuffleId, mapId, 
taskContext.attemptNumber());
+    shuffleClient.prepareForMergeData(shuffleId, mapId, encodedAttemptId);
     closeWrite();
-    shuffleClient.pushMergedData(shuffleId, mapId, 
taskContext.attemptNumber());
+    shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);
     writeMetrics.incWriteTime(System.nanoTime() - pushMergedDataTime);
     updateRecordsWrittenMetrics();
 
     long waitStartTime = System.nanoTime();
-    shuffleClient.mapperEnd(shuffleId, mapId, taskContext.attemptNumber(), 
numMappers);
+    shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
     writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
 
     BlockManagerId bmId = SparkEnv.get().blockManager().shuffleServerId();
@@ -408,7 +410,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         }
       }
     } finally {
-      shuffleClient.cleanup(shuffleId, mapId, taskContext.attemptNumber());
+      shuffleClient.cleanup(shuffleId, mapId, encodedAttemptId);
     }
   }
 
diff --git 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
index 1c8900bb2..5717910ee 100644
--- 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
+++ 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
@@ -62,6 +62,7 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private final ShuffleWriteMetricsReporter writeMetrics;
   private final int shuffleId;
   private final int mapId;
+  private final int encodedAttemptId;
   private final TaskContext taskContext;
   private final ShuffleClient shuffleClient;
   private final int numMappers;
@@ -116,6 +117,7 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     this.mapId = taskContext.partitionId();
     this.dep = dep;
     this.shuffleId = shuffleId;
+    this.encodedAttemptId = 
SparkCommonUtils.getEncodedAttemptNumber(taskContext);
     SerializerInstance serializer = dep.serializer().newInstance();
     this.partitioner = dep.partitioner();
     this.writeMetrics = metrics;
@@ -143,7 +145,7 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
               taskContext,
               shuffleId,
               mapId,
-              taskContext.attemptNumber(),
+              encodedAttemptId,
               taskContext.taskAttemptId(),
               numMappers,
               numPartitions,
@@ -348,7 +350,7 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         shuffleClient.pushData(
             shuffleId,
             mapId,
-            taskContext.attemptNumber(),
+            encodedAttemptId,
             partitionId,
             buffer,
             0,
@@ -371,12 +373,12 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     pusher.pushData(false);
     pusher.close(true);
 
-    shuffleClient.pushMergedData(shuffleId, mapId, 
taskContext.attemptNumber());
+    shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);
     writeMetrics.incWriteTime(System.nanoTime() - pushStartTime);
     writeMetrics.incRecordsWritten(tmpRecordsWritten);
 
     long waitStartTime = System.nanoTime();
-    shuffleClient.mapperEnd(shuffleId, mapId, taskContext.attemptNumber(), 
numMappers);
+    shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
     writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
   }
 
@@ -403,7 +405,7 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         }
       }
     } finally {
-      shuffleClient.cleanup(shuffleId, mapId, taskContext.attemptNumber());
+      shuffleClient.cleanup(shuffleId, mapId, encodedAttemptId);
     }
   }
 
diff --git 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index 5d0a2af28..1a3e5cca1 100644
--- 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++ 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -107,6 +107,7 @@ public class SparkShuffleManager implements ShuffleManager {
           key,
           defaultValue);
     }
+    SparkCommonUtils.validateAttemptConfig(conf);
     this.conf = conf;
     this.isDriver = isDriver;
     this.celebornConf = SparkUtils.fromSparkConf(conf);
diff --git 
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
 
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index f2cf2c21b..c064fea40 100644
--- 
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++ 
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -66,6 +66,7 @@ class CelebornShuffleReader[K, C](
 
   private val exceptionRef = new AtomicReference[IOException]
   private val throwsFetchFailure = handle.throwsFetchFailure
+  private val encodedAttemptId = 
SparkCommonUtils.getEncodedAttemptNumber(context)
 
   override def read(): Iterator[Product2[K, C]] = {
 
@@ -193,7 +194,7 @@ class CelebornShuffleReader[K, C](
             shuffleId,
             handle.shuffleId,
             partitionId,
-            context.attemptNumber(),
+            encodedAttemptId,
             startMapIndex,
             endMapIndex,
             if (throwsFetchFailure) 
ExceptionMakerHelper.SHUFFLE_FETCH_FAILURE_EXCEPTION_MAKER
diff --git 
a/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
 
b/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
index 4e998ffc7..7b790cbbf 100644
--- 
a/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
+++ 
b/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.shuffle.celeborn
 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.sql.internal.SQLConf
 import org.junit
+import org.junit.Assert
 import org.junit.runner.RunWith
 import org.junit.runners.JUnit4
 
@@ -89,4 +90,33 @@ class SparkShuffleManagerSuite extends Logging {
     // scalastyle:on println
     sc.stop()
   }
+
+  @junit.Test // The manager constructor must reject attempt limits that do not fit the encoded attempt id.
+  def testWrongSparkConfMaxAttemptLimit(): Unit = {
+    val conf = new SparkConf().setIfMissing("spark.master", "local")
+      .setIfMissing(
+        "spark.shuffle.manager",
+        "org.apache.spark.shuffle.celeborn.SparkShuffleManager")
+      .set(s"spark.${CelebornConf.MASTER_ENDPOINTS.key}", "localhost:9097")
+      .set(s"spark.${CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key}", "false")
+      .set("spark.shuffle.service.enabled", "false")
+      .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
+
+    // With the attempt limits left unset, constructing the manager is expected to succeed.
+    new SparkShuffleManager(conf, true)
+
+    conf
+      .set("spark.stage.maxConsecutiveAttempts", "32768") // 32768 == 1 << 15, the smallest rejected value
+      .set("spark.task.maxFailures", "10")
+    try {
+      new SparkShuffleManager(conf, true)
+      Assert.fail() // reached only if the constructor did not throw
+    } catch {
+      case e: IllegalArgumentException =>
+        Assert.assertTrue(
+          e.getMessage.contains("The spark.stage.maxConsecutiveAttempts should 
be less than 32768"))
+      case _: Throwable => // also catches the AssertionError raised by Assert.fail() above
+        Assert.fail()
+    }
+  }
 }

Reply via email to