This is an automated email from the ASF dual-hosted git repository.

fchen pushed a commit to branch branch-0.4
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/branch-0.4 by this push:
     new f7760359a [CELEBORN-1496][0.4] Differentiate map results with only 
different stageAttemptId
f7760359a is described below

commit f7760359a5d92f028852a3435b587ab6a6e2f529
Author: jiang13021 <[email protected]>
AuthorDate: Mon Sep 2 13:45:57 2024 +0800

    [CELEBORN-1496][0.4] Differentiate map results with only different 
stageAttemptId
    
    backport https://github.com/apache/celeborn/pull/2609 to branch-0.4
    
    ### What changes were proposed in this pull request?
    
    Encode attemptNumber as (stageAttemptId << 16) | taskAttemptNumber, so that
    map results that differ only in stageAttemptId can be told apart.
    
    ### Why are the changes needed?
    
    If map tasks that differ only in stageAttemptId cannot be told apart, a
    shuffle read may mix the shuffle write batches of two map tasks, causing a
    data correctness issue.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added a unit test:
    org.apache.spark.shuffle.celeborn.SparkShuffleManagerSuite#testWrongSparkConfMaxAttemptLimit
    
    Closes #2609 from jiang13021/spark_stage_attempt_id.
    
    Lead-authored-by: jiang13021 <jiangyanze.jyzantgroup.com>
    
    Closes #2717 from cfmcgrady/CELEBORN-1496-branch-0.4.
    
    Authored-by: jiang13021 <[email protected]>
    Signed-off-by: Fu Chen <[email protected]>
---
 .../spark/shuffle/celeborn/SparkCommonUtils.java   | 53 ++++++++++++++++++++++
 .../shuffle/celeborn/HashBasedShuffleWriter.java   | 16 ++++---
 .../shuffle/celeborn/SortBasedShuffleWriter.java   | 12 +++--
 .../shuffle/celeborn/SparkShuffleManager.java      |  1 +
 .../shuffle/celeborn/CelebornShuffleReader.scala   |  3 +-
 .../celeborn/CelebornShuffleManagerSuite.scala     | 30 ++++++++++++
 .../CelebornColumnarShuffleReaderSuite.scala       | 12 +++--
 .../shuffle/celeborn/HashBasedShuffleWriter.java   | 16 ++++---
 .../shuffle/celeborn/SortBasedShuffleWriter.java   | 12 +++--
 .../shuffle/celeborn/SparkShuffleManager.java      |  1 +
 .../shuffle/celeborn/CelebornShuffleReader.scala   |  3 +-
 .../celeborn/CelebornShuffleManagerSuite.scala     | 30 ++++++++++++
 12 files changed, 160 insertions(+), 29 deletions(-)

diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java
new file mode 100644
index 000000000..f00fe063e
--- /dev/null
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.scheduler.DAGScheduler;
+
+/**
+ * Utilities shared by the Spark 2 and Spark 3 Celeborn clients for building the map attempt id
+ * that is passed to the shuffle client.
+ */
+public class SparkCommonUtils {
+  /**
+   * Checks that the Spark retry limits fit into the bit layout used by
+   * {@link #getEncodedAttemptNumber(TaskContext)}.
+   *
+   * @param conf Spark configuration from which {@code spark.stage.maxConsecutiveAttempts} and
+   *     {@code spark.task.maxFailures} are read
+   * @throws IllegalArgumentException if {@code spark.stage.maxConsecutiveAttempts} is at least
+   *     32768 ({@code 1 << 15}) or {@code spark.task.maxFailures} is at least 65536
+   *     ({@code 1 << 16})
+   */
+  public static void validateAttemptConfig(SparkConf conf) throws IllegalArgumentException {
+    int maxStageAttempts =
+        conf.getInt(
+            "spark.stage.maxConsecutiveAttempts",
+            DAGScheduler.DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS());
+    // In Spark 2, the parameter is referred to as MAX_TASK_FAILURES, while in Spark 3, it has been
+    // changed to TASK_MAX_FAILURES. The default value for both is consistently set to 4.
+    int maxTaskAttempts = conf.getInt("spark.task.maxFailures", 4);
+    if (maxStageAttempts >= (1 << 15) || maxTaskAttempts >= (1 << 16)) {
+      // The map attemptId is a non-negative number constructed from
+      // both stageAttemptNumber and taskAttemptNumber.
+      // The high 16 bits of the map attemptId are used for the stageAttemptNumber,
+      // and the low 16 bits are used for the taskAttemptNumber.
+      // So spark.stage.maxConsecutiveAttempts should be less than 32768 (1 << 15)
+      // and spark.task.maxFailures should be less than 65536 (1 << 16).
+      throw new IllegalArgumentException(
+          "The spark.stage.maxConsecutiveAttempts should be less than 32768 (currently "
+              + maxStageAttempts
+              // Leading space after ")" keeps the two halves of the message from running together.
+              + ") and spark.task.maxFailures should be less than 65536 (currently "
+              + maxTaskAttempts
+              + ").");
+    }
+  }
+
+  /**
+   * Packs the stage attempt number and the task attempt number of a task into a single int.
+   *
+   * @param context task context supplying {@code stageAttemptNumber()} and {@code attemptNumber()}
+   * @return {@code (stageAttemptNumber << 16) | attemptNumber}
+   */
+  public static int getEncodedAttemptNumber(TaskContext context) {
+    return (context.stageAttemptNumber() << 16) | context.attemptNumber();
+  }
+}
diff --git 
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
index 407b32849..06d7ccc72 100644
--- 
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
+++ 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
@@ -68,6 +68,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private final ShuffleWriteMetrics writeMetrics;
   private final int shuffleId;
   private final int mapId;
+  private final int encodedAttemptId;
   private final TaskContext taskContext;
   private final ShuffleClient shuffleClient;
   private final int numMappers;
@@ -112,6 +113,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     this.mapId = mapId;
     this.dep = handle.dependency();
     this.shuffleId = shuffleId;
+    this.encodedAttemptId = 
SparkCommonUtils.getEncodedAttemptNumber(taskContext);
     SerializerInstance serializer = dep.serializer().newInstance();
     this.partitioner = dep.partitioner();
     this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics();
@@ -146,7 +148,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
           new DataPusher(
               shuffleId,
               mapId,
-              taskContext.attemptNumber(),
+              encodedAttemptId,
               taskContext.taskAttemptId(),
               numMappers,
               numPartitions,
@@ -273,7 +275,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         shuffleClient.pushData(
             shuffleId,
             mapId,
-            taskContext.attemptNumber(),
+            encodedAttemptId,
             partitionId,
             buffer,
             0,
@@ -318,7 +320,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     // here we wait for all the in-flight batches to return which sent by 
dataPusher thread
     dataPusher.waitOnTermination();
     sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());
-    shuffleClient.prepareForMergeData(shuffleId, mapId, 
taskContext.attemptNumber());
+    shuffleClient.prepareForMergeData(shuffleId, mapId, encodedAttemptId);
 
     // merge and push residual data to reduce network traffic
     // NB: since dataPusher thread have no in-flight data at this point,
@@ -330,7 +332,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
             shuffleClient.mergeData(
                 shuffleId,
                 mapId,
-                taskContext.attemptNumber(),
+                encodedAttemptId,
                 i,
                 sendBuffers[i],
                 0,
@@ -343,7 +345,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         writeMetrics.incBytesWritten(bytesWritten);
       }
     }
-    shuffleClient.pushMergedData(shuffleId, mapId, 
taskContext.attemptNumber());
+    shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);
 
     updateMapStatus();
 
@@ -352,7 +354,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     sendOffsets = null;
 
     long waitStartTime = System.nanoTime();
-    shuffleClient.mapperEnd(shuffleId, mapId, taskContext.attemptNumber(), 
numMappers);
+    shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
     writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
 
     BlockManagerId bmId = SparkEnv.get().blockManager().shuffleServerId();
@@ -389,7 +391,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         }
       }
     } finally {
-      shuffleClient.cleanup(shuffleId, mapId, taskContext.attemptNumber());
+      shuffleClient.cleanup(shuffleId, mapId, encodedAttemptId);
     }
   }
 }
diff --git 
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
index 0985c9b69..a8bd23c21 100644
--- 
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
+++ 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
@@ -62,6 +62,7 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private final ShuffleWriteMetrics writeMetrics;
   private final int shuffleId;
   private final int mapId;
+  private final int encodedAttemptId;
   private final TaskContext taskContext;
   private final ShuffleClient shuffleClient;
   private final int numMappers;
@@ -102,6 +103,7 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     this.mapId = taskContext.partitionId();
     this.dep = dep;
     this.shuffleId = shuffleId;
+    this.encodedAttemptId = 
SparkCommonUtils.getEncodedAttemptNumber(taskContext);
     SerializerInstance serializer = dep.serializer().newInstance();
     this.partitioner = dep.partitioner();
     this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics();
@@ -130,7 +132,7 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
             taskContext,
             shuffleId,
             mapId,
-            taskContext.attemptNumber(),
+            encodedAttemptId,
             taskContext.taskAttemptId(),
             numMappers,
             numPartitions,
@@ -277,7 +279,7 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         shuffleClient.pushData(
             shuffleId,
             mapId,
-            taskContext.attemptNumber(),
+            encodedAttemptId,
             partitionId,
             buffer,
             0,
@@ -295,12 +297,12 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     pusher.close();
     writeMetrics.incWriteTime(System.nanoTime() - pushStartTime);
 
-    shuffleClient.pushMergedData(shuffleId, mapId, 
taskContext.attemptNumber());
+    shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);
 
     updateMapStatus();
 
     long waitStartTime = System.nanoTime();
-    shuffleClient.mapperEnd(shuffleId, mapId, taskContext.attemptNumber(), 
numMappers);
+    shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
     writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
   }
 
@@ -336,7 +338,7 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     } catch (IOException e) {
       return Option.apply(null);
     } finally {
-      shuffleClient.cleanup(shuffleId, mapId, taskContext.attemptNumber());
+      shuffleClient.cleanup(shuffleId, mapId, encodedAttemptId);
     }
   }
 }
diff --git 
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index 157fc2d0b..cabecfb15 100644
--- 
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++ 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -64,6 +64,7 @@ public class SparkShuffleManager implements ShuffleManager {
   private ExecutorShuffleIdTracker shuffleIdTracker = new 
ExecutorShuffleIdTracker();
 
   public SparkShuffleManager(SparkConf conf, boolean isDriver) {
+    SparkCommonUtils.validateAttemptConfig(conf);
     this.conf = conf;
     this.isDriver = isDriver;
     this.celebornConf = SparkUtils.fromSparkConf(conf);
diff --git 
a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
 
b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index b74db710c..056b94c94 100644
--- 
a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++ 
b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -56,6 +56,7 @@ class CelebornShuffleReader[K, C](
     handle.extension)
 
   private val exceptionRef = new AtomicReference[IOException]
+  private val encodedAttemptId = 
SparkCommonUtils.getEncodedAttemptNumber(context)
 
   override def read(): Iterator[Product2[K, C]] = {
 
@@ -96,7 +97,7 @@ class CelebornShuffleReader[K, C](
               val inputStream = shuffleClient.readPartition(
                 shuffleId,
                 partitionId,
-                context.attemptNumber(),
+                encodedAttemptId,
                 startMapIndex,
                 endMapIndex,
                 metricsCallback)
diff --git 
a/client-spark/spark-2/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
 
b/client-spark/spark-2/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
index cb6c3e0ab..0786904b8 100644
--- 
a/client-spark/spark-2/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
+++ 
b/client-spark/spark-2/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.shuffle.celeborn
 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.sql.internal.SQLConf
 import org.junit
+import org.junit.Assert
 import org.junit.runner.RunWith
 import org.junit.runners.JUnit4
 
@@ -67,4 +68,33 @@ class SparkShuffleManagerSuite extends Logging {
     sc.stop()
   }
 
+  @junit.Test // SparkShuffleManager must reject an out-of-range spark.stage.maxConsecutiveAttempts
+  def testWrongSparkConfMaxAttemptLimit(): Unit = {
+    val conf = new SparkConf().setIfMissing("spark.master", "local")
+      .setIfMissing(
+        "spark.shuffle.manager",
+        "org.apache.spark.shuffle.celeborn.SparkShuffleManager")
+      .set(s"spark.${CelebornConf.MASTER_ENDPOINTS.key}", "localhost:9097")
+      .set(s"spark.${CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key}", "false")
+      .set("spark.shuffle.service.enabled", "false")
+      .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
+
+    // Default attempt limits: constructing the manager is expected to succeed.
+    new SparkShuffleManager(conf, true)
+
+    conf
+      .set("spark.stage.maxConsecutiveAttempts", "32768") // 1 << 15, the smallest rejected value
+      .set("spark.task.maxFailures", "10") // within range, so only the stage limit is violated
+    try {
+      new SparkShuffleManager(conf, true)
+      Assert.fail() // reached only if the constructor did not throw
+    } catch {
+      case e: IllegalArgumentException =>
+        Assert.assertTrue(
+          e.getMessage.contains("The spark.stage.maxConsecutiveAttempts should 
be less than 32768"))
+      case _: Throwable => // anything else, including the failure raised in the try body
+        Assert.fail()
+    }
+  }
+
 }
diff --git 
a/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
 
b/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
index 5df434f54..158cb9406 100644
--- 
a/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
+++ 
b/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.shuffle.celeborn
 
-import org.apache.spark.{ShuffleDependency, SparkConf}
+import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext}
 import org.apache.spark.serializer.{KryoSerializer, KryoSerializerInstance}
 import org.apache.spark.sql.execution.UnsafeRowSerializer
 import 
org.apache.spark.sql.execution.columnar.CelebornColumnarBatchSerializerInstance
@@ -45,6 +45,9 @@ class CelebornColumnarShuffleReaderSuite {
 
     var shuffleClient: MockedStatic[ShuffleClient] = null
     try {
+      val taskContext = Mockito.mock(classOf[TaskContext])
+      Mockito.when(taskContext.stageAttemptNumber).thenReturn(0)
+      Mockito.when(taskContext.attemptNumber).thenReturn(0)
       shuffleClient = Mockito.mockStatic(classOf[ShuffleClient])
       val shuffleReader = SparkUtils.createColumnarShuffleReader(
         handle,
@@ -52,7 +55,7 @@ class CelebornColumnarShuffleReaderSuite {
         10,
         0,
         10,
-        null,
+        taskContext,
         new CelebornConf(),
         null,
         new ExecutorShuffleIdTracker())
@@ -68,6 +71,9 @@ class CelebornColumnarShuffleReaderSuite {
   def columnarShuffleReaderNewSerializerInstance(): Unit = {
     var shuffleClient: MockedStatic[ShuffleClient] = null
     try {
+      val taskContext = Mockito.mock(classOf[TaskContext])
+      Mockito.when(taskContext.stageAttemptNumber).thenReturn(0)
+      Mockito.when(taskContext.attemptNumber).thenReturn(0)
       shuffleClient = Mockito.mockStatic(classOf[ShuffleClient])
       val shuffleReader = SparkUtils.createColumnarShuffleReader(
         new CelebornShuffleHandle[Int, String, String](
@@ -83,7 +89,7 @@ class CelebornColumnarShuffleReaderSuite {
         10,
         0,
         10,
-        null,
+        taskContext,
         new CelebornConf(),
         null,
         new ExecutorShuffleIdTracker())
diff --git 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
index a3024bf00..c3808b6f1 100644
--- 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
+++ 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
@@ -68,6 +68,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private final ShuffleWriteMetricsReporter writeMetrics;
   private final int shuffleId;
   private final int mapId;
+  private final int encodedAttemptId;
   private final TaskContext taskContext;
   private final ShuffleClient shuffleClient;
   private final int numMappers;
@@ -112,6 +113,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     this.mapId = taskContext.partitionId();
     this.dep = handle.dependency();
     this.shuffleId = shuffleId;
+    this.encodedAttemptId = 
SparkCommonUtils.getEncodedAttemptNumber(taskContext);
     SerializerInstance serializer = dep.serializer().newInstance();
     this.partitioner = dep.partitioner();
     this.writeMetrics = metrics;
@@ -143,7 +145,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
           new DataPusher(
               shuffleId,
               mapId,
-              taskContext.attemptNumber(),
+              encodedAttemptId,
               taskContext.taskAttemptId(),
               numMappers,
               numPartitions,
@@ -278,7 +280,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         shuffleClient.pushData(
             shuffleId,
             mapId,
-            taskContext.attemptNumber(),
+            encodedAttemptId,
             partitionId,
             buffer,
             0,
@@ -342,7 +344,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         shuffleClient.mergeData(
             shuffleId,
             mapId,
-            taskContext.attemptNumber(),
+            encodedAttemptId,
             partitionId,
             buffer,
             offset,
@@ -358,15 +360,15 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     long pushMergedDataTime = System.nanoTime();
     dataPusher.waitOnTermination();
     sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());
-    shuffleClient.prepareForMergeData(shuffleId, mapId, 
taskContext.attemptNumber());
+    shuffleClient.prepareForMergeData(shuffleId, mapId, encodedAttemptId);
     closeWrite();
-    shuffleClient.pushMergedData(shuffleId, mapId, 
taskContext.attemptNumber());
+    shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);
     writeMetrics.incWriteTime(System.nanoTime() - pushMergedDataTime);
 
     updateMapStatus();
 
     long waitStartTime = System.nanoTime();
-    shuffleClient.mapperEnd(shuffleId, mapId, taskContext.attemptNumber(), 
numMappers);
+    shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
     writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
 
     BlockManagerId bmId = SparkEnv.get().blockManager().shuffleServerId();
@@ -403,7 +405,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         }
       }
     } finally {
-      shuffleClient.cleanup(shuffleId, mapId, taskContext.attemptNumber());
+      shuffleClient.cleanup(shuffleId, mapId, encodedAttemptId);
     }
   }
 
diff --git 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
index cbb8ec725..7984f9ec8 100644
--- 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
+++ 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
@@ -62,6 +62,7 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private final ShuffleWriteMetricsReporter writeMetrics;
   private final int shuffleId;
   private final int mapId;
+  private final int encodedAttemptId;
   private final TaskContext taskContext;
   private final ShuffleClient shuffleClient;
   private final int numMappers;
@@ -102,6 +103,7 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     this.mapId = taskContext.partitionId();
     this.dep = dep;
     this.shuffleId = shuffleId;
+    this.encodedAttemptId = 
SparkCommonUtils.getEncodedAttemptNumber(taskContext);
     SerializerInstance serializer = dep.serializer().newInstance();
     this.partitioner = dep.partitioner();
     this.writeMetrics = metrics;
@@ -129,7 +131,7 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
             taskContext,
             shuffleId,
             mapId,
-            taskContext.attemptNumber(),
+            encodedAttemptId,
             taskContext.taskAttemptId(),
             numMappers,
             numPartitions,
@@ -298,7 +300,7 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         shuffleClient.pushData(
             shuffleId,
             mapId,
-            taskContext.attemptNumber(),
+            encodedAttemptId,
             partitionId,
             buffer,
             0,
@@ -315,13 +317,13 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     pusher.pushData();
     pusher.close();
 
-    shuffleClient.pushMergedData(shuffleId, mapId, 
taskContext.attemptNumber());
+    shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);
     writeMetrics.incWriteTime(System.nanoTime() - pushStartTime);
 
     updateMapStatus();
 
     long waitStartTime = System.nanoTime();
-    shuffleClient.mapperEnd(shuffleId, mapId, taskContext.attemptNumber(), 
numMappers);
+    shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
     writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
   }
 
@@ -355,7 +357,7 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         }
       }
     } finally {
-      shuffleClient.cleanup(shuffleId, mapId, taskContext.attemptNumber());
+      shuffleClient.cleanup(shuffleId, mapId, encodedAttemptId);
     }
   }
 
diff --git 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index d97e774bf..35cc39846 100644
--- 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++ 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -94,6 +94,7 @@ public class SparkShuffleManager implements ShuffleManager {
               + "use Celeborn as Remote Shuffle Service to avoid performance 
degradation.",
           SQLConf.LOCAL_SHUFFLE_READER_ENABLED().key());
     }
+    SparkCommonUtils.validateAttemptConfig(conf);
     this.conf = conf;
     this.isDriver = isDriver;
     this.celebornConf = SparkUtils.fromSparkConf(conf);
diff --git 
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
 
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index 32c618bc2..9ba26116e 100644
--- 
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++ 
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -58,6 +58,7 @@ class CelebornShuffleReader[K, C](
 
   private val exceptionRef = new AtomicReference[IOException]
   private val throwsFetchFailure = handle.throwsFetchFailure
+  private val encodedAttemptId = 
SparkCommonUtils.getEncodedAttemptNumber(context)
 
   override def read(): Iterator[Product2[K, C]] = {
 
@@ -117,7 +118,7 @@ class CelebornShuffleReader[K, C](
                 shuffleId,
                 handle.shuffleId,
                 partitionId,
-                context.attemptNumber(),
+                encodedAttemptId,
                 startMapIndex,
                 endMapIndex,
                 if (throwsFetchFailure) exceptionMaker else null,
diff --git 
a/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
 
b/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
index 019d3f48e..21aa1074c 100644
--- 
a/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
+++ 
b/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.shuffle.celeborn
 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.sql.internal.SQLConf
 import org.junit
+import org.junit.Assert
 import org.junit.runner.RunWith
 import org.junit.runners.JUnit4
 
@@ -67,4 +68,33 @@ class SparkShuffleManagerSuite extends Logging {
     // scalastyle:on println
     sc.stop()
   }
+
+  @junit.Test // SparkShuffleManager must reject an out-of-range spark.stage.maxConsecutiveAttempts
+  def testWrongSparkConfMaxAttemptLimit(): Unit = {
+    val conf = new SparkConf().setIfMissing("spark.master", "local")
+      .setIfMissing(
+        "spark.shuffle.manager",
+        "org.apache.spark.shuffle.celeborn.SparkShuffleManager")
+      .set(s"spark.${CelebornConf.MASTER_ENDPOINTS.key}", "localhost:9097")
+      .set(s"spark.${CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key}", "false")
+      .set("spark.shuffle.service.enabled", "false")
+      .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
+
+    // Default attempt limits: constructing the manager is expected to succeed.
+    new SparkShuffleManager(conf, true)
+
+    conf
+      .set("spark.stage.maxConsecutiveAttempts", "32768") // 1 << 15, the smallest rejected value
+      .set("spark.task.maxFailures", "10") // within range, so only the stage limit is violated
+    try {
+      new SparkShuffleManager(conf, true)
+      Assert.fail() // reached only if the constructor did not throw
+    } catch {
+      case e: IllegalArgumentException =>
+        Assert.assertTrue(
+          e.getMessage.contains("The spark.stage.maxConsecutiveAttempts should 
be less than 32768"))
+      case _: Throwable => // anything else, including the failure raised in the try body
+        Assert.fail()
+    }
+  }
 }

Reply via email to