This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch branch-0.5
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/branch-0.5 by this push:
     new 321b4e3ca [CELEBORN-1071] Support stage rerun for shuffle data lost
321b4e3ca is described below

commit 321b4e3ca15aa12239a22f614b144cbc1d8ac291
Author: mingji <[email protected]>
AuthorDate: Tue Nov 12 10:07:26 2024 +0800

    [CELEBORN-1071] Support stage rerun for shuffle data lost
    
    ### What changes were proposed in this pull request?
    If shuffle data is lost and throwing fetch failures is enabled, a stage 
rerun is triggered.
    
    ### Why are the changes needed?
    Rerun stages for shuffle-data-lost scenarios.
    
    ### Does this PR introduce _any_ user-facing change?
    NO.
    
    ### How was this patch tested?
    GA.
    
    Closes #2894 from FMX/b1701.
    
    Authored-by: mingji <[email protected]>
    Signed-off-by: Shuang <[email protected]>
    (cherry picked from commit 42d5d426a1382162cb75a93aea48bcdef5389233)
    Signed-off-by: Shuang <[email protected]>
---
 .../shuffle/celeborn/CelebornShuffleReader.scala   | 55 +++++++++--------
 .../apache/celeborn/client/ShuffleClientImpl.java  |  2 +-
 .../commit/ReducePartitionCommitHandler.scala      |  8 ++-
 .../org/apache/celeborn/common/CelebornConf.scala  | 22 +++++++
 .../tests/spark/CelebornShuffleLostSuite.scala     | 71 ++++++++++++++++++++++
 5 files changed, 130 insertions(+), 28 deletions(-)

diff --git 
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
 
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index ba57a44b9..0b5bfb401 100644
--- 
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++ 
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -34,6 +34,7 @@ import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.ExternalSorter
 
 import org.apache.celeborn.client.ShuffleClient
+import org.apache.celeborn.client.ShuffleClientImpl.ReduceFileGroups
 import org.apache.celeborn.client.read.{CelebornInputStream, MetricsCallback}
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.exception.{CelebornIOException, 
PartitionUnRetryAbleException}
@@ -104,8 +105,16 @@ class CelebornShuffleReader[K, C](
     val localFetchEnabled = conf.enableReadLocalShuffleFile
     val localHostAddress = Utils.localHostName(conf)
     val shuffleKey = Utils.makeShuffleKey(handle.appUniqueId, shuffleId)
-    // startPartition is irrelevant
-    val fileGroups = shuffleClient.updateFileGroup(shuffleId, startPartition)
+    var fileGroups: ReduceFileGroups = null
+    try {
+      // startPartition is irrelevant
+      fileGroups = shuffleClient.updateFileGroup(shuffleId, startPartition)
+    } catch {
+      case ce @ (_: CelebornIOException | _: PartitionUnRetryAbleException) =>
+        handleFetchExceptions(shuffleId, 0, ce)
+      case e: Throwable => throw e
+    }
+
     // host-port -> (TransportClient, PartitionLocation Array, 
PbOpenStreamList)
     val workerRequestMap = new util.HashMap[
       String,
@@ -245,18 +254,7 @@ class CelebornShuffleReader[K, C](
           if (exceptionRef.get() != null) {
             exceptionRef.get() match {
               case ce @ (_: CelebornIOException | _: 
PartitionUnRetryAbleException) =>
-                if (throwsFetchFailure &&
-                  shuffleClient.reportShuffleFetchFailure(handle.shuffleId, 
shuffleId)) {
-                  throw new FetchFailedException(
-                    null,
-                    handle.shuffleId,
-                    -1,
-                    -1,
-                    partitionId,
-                    SparkUtils.FETCH_FAILURE_ERROR_MSG + handle.shuffleId + 
"/" + shuffleId,
-                    ce)
-                } else
-                  throw ce
+                handleFetchExceptions(handle.shuffleId, partitionId, ce)
               case e => throw e
             }
           }
@@ -289,18 +287,7 @@ class CelebornShuffleReader[K, C](
         iter
       } catch {
         case e @ (_: CelebornIOException | _: PartitionUnRetryAbleException) =>
-          if (throwsFetchFailure &&
-            shuffleClient.reportShuffleFetchFailure(handle.shuffleId, 
shuffleId)) {
-            throw new FetchFailedException(
-              null,
-              handle.shuffleId,
-              -1,
-              -1,
-              partitionId,
-              SparkUtils.FETCH_FAILURE_ERROR_MSG + handle.shuffleId + "/" + 
shuffleId,
-              e)
-          } else
-            throw e
+          handleFetchExceptions(handle.shuffleId, partitionId, e)
       }
     }
 
@@ -380,6 +367,22 @@ class CelebornShuffleReader[K, C](
     }
   }
 
+  private def handleFetchExceptions(shuffleId: Int, partitionId: Int, ce: 
Throwable) = {
+    if (throwsFetchFailure &&
+      shuffleClient.reportShuffleFetchFailure(handle.shuffleId, shuffleId)) {
+      logWarning(s"Handle fetch exceptions for ${shuffleId}-${partitionId}", 
ce)
+      throw new FetchFailedException(
+        null,
+        handle.shuffleId,
+        -1,
+        -1,
+        partitionId,
+        SparkUtils.FETCH_FAILURE_ERROR_MSG + handle.shuffleId + "/" + 
shuffleId,
+        ce)
+    } else
+      throw ce
+  }
+
   def newSerializerInstance(dep: ShuffleDependency[K, _, C]): 
SerializerInstance = {
     dep.serializer.newInstance()
   }
diff --git 
a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java 
b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index 7489d4f49..00018d95d 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -138,7 +138,7 @@ public class ShuffleClientImpl extends ShuffleClient {
 
   private final ReviveManager reviveManager;
 
-  protected static class ReduceFileGroups {
+  public static class ReduceFileGroups {
     public Map<Integer, Set<PartitionLocation>> partitionGroups;
     public int[] mapAttempts;
     public Set<Integer> partitionIds;
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
 
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
index 23d6a7b8d..1b1d8be39 100644
--- 
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
+++ 
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
@@ -68,6 +68,8 @@ class ReducePartitionCommitHandler(
   private val inProcessStageEndShuffleSet = ConcurrentHashMap.newKeySet[Int]()
   private val shuffleMapperAttempts = JavaUtils.newConcurrentHashMap[Int, 
Array[Int]]()
   private val stageEndTimeout = conf.clientPushStageEndTimeout
+  private val mockShuffleLost = conf.testMockShuffleLost
+  private val mockShuffleLostShuffle = conf.testMockShuffleLostShuffle
 
   private val rpcCacheSize = conf.clientRpcCacheSize
   private val rpcCacheConcurrencyLevel = conf.clientRpcCacheConcurrencyLevel
@@ -94,7 +96,11 @@ class ReducePartitionCommitHandler(
   }
 
   override def isStageDataLost(shuffleId: Int): Boolean = {
-    dataLostShuffleSet.contains(shuffleId)
+    if (mockShuffleLost) {
+      mockShuffleLostShuffle == shuffleId
+    } else {
+      dataLostShuffleSet.contains(shuffleId)
+    }
   }
 
   override def isPartitionInProcess(shuffleId: Int, partitionId: Int): Boolean 
= {
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala 
b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index 9fe6ebb0e..dc0fd4e86 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -1245,6 +1245,8 @@ class CelebornConf(loadDefaults: Boolean) extends 
Cloneable with Logging with Se
   def testFetchFailure: Boolean = get(TEST_CLIENT_FETCH_FAILURE)
   def testMockDestroySlotsFailure: Boolean = 
get(TEST_CLIENT_MOCK_DESTROY_SLOTS_FAILURE)
   def testMockCommitFilesFailure: Boolean = 
get(TEST_CLIENT_MOCK_COMMIT_FILES_FAILURE)
+  def testMockShuffleLost: Boolean = get(TEST_CLIENT_MOCK_SHUFFLE_LOST)
+  def testMockShuffleLostShuffle: Int = 
get(TEST_CLIENT_MOCK_SHUFFLE_LOST_SHUFFLE)
   def testPushPrimaryDataTimeout: Boolean = 
get(TEST_CLIENT_PUSH_PRIMARY_DATA_TIMEOUT)
   def testPushReplicaDataTimeout: Boolean = 
get(TEST_WORKER_PUSH_REPLICA_DATA_TIMEOUT)
   def testRetryRevive: Boolean = get(TEST_CLIENT_RETRY_REVIVE)
@@ -3716,6 +3718,26 @@ object CelebornConf extends Logging {
       .booleanConf
       .createWithDefault(false)
 
+  val TEST_CLIENT_MOCK_SHUFFLE_LOST: ConfigEntry[Boolean] =
+    buildConf("celeborn.test.client.mockShuffleLost")
+      .internal
+      .categories("test", "client")
+      .doc("Mock shuffle lost.")
+      .version("0.5.2")
+      .internal
+      .booleanConf
+      .createWithDefault(false)
+
+  val TEST_CLIENT_MOCK_SHUFFLE_LOST_SHUFFLE: ConfigEntry[Int] =
+    buildConf("celeborn.test.client.mockShuffleLostShuffle")
+      .internal
+      .categories("test", "client")
+      .doc("Mock shuffle lost for shuffle")
+      .version("0.5.2")
+      .internal
+      .intConf
+      .createWithDefault(0)
+
   val CLIENT_PUSH_REPLICATE_ENABLED: ConfigEntry[Boolean] =
     buildConf("celeborn.client.push.replicate.enabled")
       .withAlternative("celeborn.push.replicate.enabled")
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornShuffleLostSuite.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornShuffleLostSuite.scala
new file mode 100644
index 000000000..8c0e8b101
--- /dev/null
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornShuffleLostSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.tests.spark
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.SparkSession
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.celeborn.client.ShuffleClient
+import org.apache.celeborn.common.protocol.ShuffleMode
+
+class CelebornShuffleLostSuite extends AnyFunSuite
+  with SparkTestBase
+  with BeforeAndAfterEach {
+
+  override def beforeEach(): Unit = {
+    ShuffleClient.reset()
+  }
+
+  override def afterEach(): Unit = {
+    System.gc()
+  }
+
+  test("celeborn shuffle data lost - hash") {
+    val sparkConf = new 
SparkConf().setAppName("celeborn-demo").setMaster("local[2]")
+    val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
+    val combineResult = combine(sparkSession)
+    val groupbyResult = groupBy(sparkSession)
+    val repartitionResult = repartition(sparkSession)
+    val sqlResult = runsql(sparkSession)
+
+    Thread.sleep(3000L)
+    sparkSession.stop()
+
+    val conf = updateSparkConf(sparkConf, ShuffleMode.HASH)
+    conf.set("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true")
+    conf.set("spark.celeborn.test.client.mockShuffleLost", "true")
+
+    val celebornSparkSession = SparkSession.builder()
+      .config(conf)
+      .getOrCreate()
+    val celebornCombineResult = combine(celebornSparkSession)
+    val celebornGroupbyResult = groupBy(celebornSparkSession)
+    val celebornRepartitionResult = repartition(celebornSparkSession)
+    val celebornSqlResult = runsql(celebornSparkSession)
+
+    assert(combineResult.equals(celebornCombineResult))
+    assert(groupbyResult.equals(celebornGroupbyResult))
+    assert(repartitionResult.equals(celebornRepartitionResult))
+    assert(combineResult.equals(celebornCombineResult))
+    assert(sqlResult.equals(celebornSqlResult))
+
+    celebornSparkSession.stop()
+  }
+}

Reply via email to