This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch branch-0.4
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/branch-0.4 by this push:
     new 93aec297b [CELEBORN-1271] Fix unregisterShuffle with 
celeborn.client.spark.fetch.throwsFetchFailure disabled
93aec297b is described below

commit 93aec297b8142044a48e47e4841467656b44a449
Author: Erik.fang <[email protected]>
AuthorDate: Thu Feb 29 16:17:54 2024 +0800

    [CELEBORN-1271] Fix unregisterShuffle with 
celeborn.client.spark.fetch.throwsFetchFailure disabled
    
    ### What changes were proposed in this pull request?
    per https://issues.apache.org/jira/browse/CELEBORN-1271
    fix the bug with SparkShuffleManager.unregisterShuffle when 
celeborn.client.spark.fetch.throwsFetchFailure=false
    
    ### Why are the changes needed?
    the bug causes shuffle data to not be cleaned up by unregisterShuffle
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    manually tested
    
    Closes #2305 from ErikFang/CELEBORN-1271-fix-unregisterShuffle.
    
    Authored-by: Erik.fang <[email protected]>
    Signed-off-by: waitinfuture <[email protected]>
---
 .../shuffle/celeborn/SparkShuffleManager.java      |  3 +-
 .../shuffle/celeborn/SparkShuffleManager.java      |  3 +-
 .../apache/celeborn/client/LifecycleManager.scala  | 24 ++++++++----
 .../tests/spark/CelebornFetchFailureSuite.scala    | 44 +++++++++++++++++++++-
 4 files changed, 62 insertions(+), 12 deletions(-)

diff --git 
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index cda992b29..0344ca6d3 100644
--- 
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++ 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -142,7 +142,8 @@ public class SparkShuffleManager implements ShuffleManager {
     }
     // For Spark driver side trigger unregister shuffle.
     if (lifecycleManager != null) {
-      lifecycleManager.unregisterAppShuffle(appShuffleId);
+      lifecycleManager.unregisterAppShuffle(
+          appShuffleId, celebornConf.clientFetchThrowsFetchFailure());
     }
     // For Spark executor side cleanup shuffle related info.
     if (shuffleClient != null) {
diff --git 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index a9b714fb6..d97e774bf 100644
--- 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++ 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -183,7 +183,8 @@ public class SparkShuffleManager implements ShuffleManager {
     }
     // For Spark driver side trigger unregister shuffle.
     if (lifecycleManager != null) {
-      lifecycleManager.unregisterAppShuffle(appShuffleId);
+      lifecycleManager.unregisterAppShuffle(
+          appShuffleId, celebornConf.clientFetchThrowsFetchFailure());
     }
     // For Spark executor side cleanup shuffle related info.
     if (shuffleClient != null) {
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index 5308ed30b..e73106d21 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -113,6 +113,10 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
   def workerSnapshots(shuffleId: Int): util.Map[WorkerInfo, 
ShufflePartitionLocationInfo] =
     shuffleAllocatedWorkers.get(shuffleId)
 
+  @VisibleForTesting
+  def getUnregisterShuffleTime(): ConcurrentHashMap[Int, Long] =
+    unregisterShuffleTime
+
   val newMapFunc: function.Function[Int, ConcurrentHashMap[Int, 
PartitionLocation]] =
     new util.function.Function[Int, ConcurrentHashMap[Int, 
PartitionLocation]]() {
       override def apply(s: Int): ConcurrentHashMap[Int, PartitionLocation] = {
@@ -908,16 +912,20 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
     logInfo(s"Unregister for $shuffleId success.")
   }
 
-  def unregisterAppShuffle(appShuffleId: Int): Unit = {
+  def unregisterAppShuffle(appShuffleId: Int, hasMapping: Boolean): Unit = {
     logInfo(s"Unregister appShuffleId $appShuffleId starts...")
     appShuffleDeterminateMap.remove(appShuffleId)
-    val shuffleIds = shuffleIdMapping.remove(appShuffleId)
-    if (shuffleIds != null) {
-      shuffleIds.synchronized(
-        shuffleIds.values.map {
-          case (shuffleId, _) =>
-            unregisterShuffle(shuffleId)
-        })
+    if (hasMapping) {
+      val shuffleIds = shuffleIdMapping.remove(appShuffleId)
+      if (shuffleIds != null) {
+        shuffleIds.synchronized(
+          shuffleIds.values.map {
+            case (shuffleId, _) =>
+              unregisterShuffle(shuffleId)
+          })
+      }
+    } else {
+      unregisterShuffle(appShuffleId)
     }
   }
 
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala
index 8983f6bd6..291bdca18 100644
--- 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala
@@ -20,9 +20,9 @@ package org.apache.celeborn.tests.spark
 import java.io.File
 import java.util.concurrent.atomic.AtomicBoolean
 
-import org.apache.spark.{SparkConf, TaskContext}
+import org.apache.spark.{SparkConf, SparkContextHelper, TaskContext}
 import org.apache.spark.shuffle.ShuffleHandle
-import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, 
ShuffleManagerHook, SparkUtils, TestCelebornShuffleManager}
+import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, 
ShuffleManagerHook, SparkShuffleManager, SparkUtils, TestCelebornShuffleManager}
 import org.apache.spark.sql.SparkSession
 import org.scalatest.BeforeAndAfterEach
 import org.scalatest.funsuite.AnyFunSuite
@@ -120,6 +120,46 @@ class CelebornFetchFailureSuite extends AnyFunSuite
       assert(elem._2.mkString(",").equals(value))
     }
 
+    val shuffleMgr = SparkContextHelper.env
+      .shuffleManager
+      .asInstanceOf[TestCelebornShuffleManager]
+    val lifecycleManager = shuffleMgr.getLifecycleManager
+
+    shuffleMgr.unregisterShuffle(0)
+    assert(lifecycleManager.getUnregisterShuffleTime().containsKey(0))
+    assert(lifecycleManager.getUnregisterShuffleTime().containsKey(1))
+
+    sparkSession.stop()
+  }
+
+  test("celeborn spark integration test - unregister shuffle with 
throwsFetchFailure disabled") {
+    val sparkConf = new 
SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
+    val sparkSession = SparkSession.builder()
+      .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
+      .config("spark.sql.shuffle.partitions", 2)
+      .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false)
+      .config("spark.celeborn.shuffle.enabled", "true")
+      .config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "false")
+      .getOrCreate()
+
+    val value = Range(1, 10000).mkString(",")
+    val tuples = sparkSession.sparkContext.parallelize(1 to 10000, 2)
+      .map { i => (i, value) }.groupByKey(16).collect()
+
+    // verify result
+    assert(tuples.length == 10000)
+    for (elem <- tuples) {
+      assert(elem._2.mkString(",").equals(value))
+    }
+
+    val shuffleMgr = SparkContextHelper.env
+      .shuffleManager
+      .asInstanceOf[SparkShuffleManager]
+    val lifecycleManager = shuffleMgr.getLifecycleManager
+
+    shuffleMgr.unregisterShuffle(0)
+    assert(lifecycleManager.getUnregisterShuffleTime().containsKey(0))
+
     sparkSession.stop()
   }
 

Reply via email to