This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 0b892a543f9 [SPARK-40932][CORE] Fix issue messages for allGather are overridden 0b892a543f9 is described below commit 0b892a543f9ea913f961eea95a4e45f1231b9a57 Author: Bobby Wang <wbo4...@gmail.com> AuthorDate: Fri Oct 28 21:06:49 2022 +0800 [SPARK-40932][CORE] Fix issue messages for allGather are overridden ### What changes were proposed in this pull request? The messages returned by allGather may be overridden by the following barrier APIs, e.g., ``` scala val messages: Array[String] = context.allGather("ABC") context.barrier() ``` the `messages` may be like `Array("", "")`, but we're expecting `Array("ABC", "ABC")`. The root cause of this issue is that the [messages obtained by allGather](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala#L102) point to the [original message](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala#L107) in local mode. So when a subsequent barrier API changes the messages, the allGather messages are changed accordingly. As a result, users can't get the correct result. This PR fixes this issue by sending back a clone of the messages. ### Why are the changes needed? The bug mentioned in this description may block some external Spark ML libraries that heavily depend on the Spark barrier API for synchronization. If the barrier mechanism can't guarantee the correctness of the barrier APIs, it will be a disaster for external Spark ML libraries. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? I added a unit test; with this PR, the unit test passes. Closes #38410 from wbo4958/allgather-issue. 
Authored-by: Bobby Wang <wbo4...@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../org/apache/spark/BarrierCoordinator.scala | 2 +- .../spark/scheduler/BarrierTaskContextSuite.scala | 23 ++++++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index 04faf7f87cf..8ffccdf664b 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -176,7 +176,7 @@ private[spark] class BarrierCoordinator( logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received update from Task " + s"$taskId, current progress: ${requesters.size}/$numTasks.") if (requesters.size == numTasks) { - requesters.foreach(_.reply(messages)) + requesters.foreach(_.reply(messages.clone())) // Finished current barrier() call successfully, clean up ContextBarrierState and // increase the barrier epoch. 
logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received all updates from " + diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala index 4f97003e2ed..26cd5374fa0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala @@ -367,4 +367,27 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with // double check we kill task success assert(System.currentTimeMillis() - startTime < 5000) } + + test("SPARK-40932, messages of allGather should not been overridden " + + "by the following barrier APIs") { + + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local[2]")) + sc.setLogLevel("INFO") + val rdd = sc.makeRDD(1 to 10, 2) + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() + // Sleep for a random time before global sync. + Thread.sleep(Random.nextInt(1000)) + // Pass partitionId message in + val message: String = context.partitionId().toString + val messages: Array[String] = context.allGather(message) + context.barrier() + Iterator.single(messages.toList) + } + val messages = rdd2.collect() + // All the task partitionIds are shared across all tasks + assert(messages.length === 2) + assert(messages.forall(_ == List("0", "1"))) + } + } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org