This is an automated email from the ASF dual-hosted git repository.
zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new f8bb2cd4 [CELEBORN-12]Retry on CommitFile request (#1011)
f8bb2cd4 is described below
commit f8bb2cd47d5271583f0a343b1173b24d63bb7669
Author: Keyong Zhou <[email protected]>
AuthorDate: Sat Nov 26 20:56:24 2022 +0800
[CELEBORN-12]Retry on CommitFile request (#1011)
---
.../apache/celeborn/client/LifecycleManager.scala | 42 +++++++----
.../common/protocol/message/StatusCode.java | 4 +-
.../org/apache/celeborn/common/CelebornConf.scala | 28 ++++++-
docs/configuration/client.md | 2 +
docs/configuration/worker.md | 2 +-
.../apache/celeborn/tests/spark/HugeDataTest.scala | 2 +-
...geDataTest.scala => RetryCommitFilesTest.scala} | 16 ++--
.../apache/celeborn/tests/spark/RssHashSuite.scala | 2 +-
.../apache/celeborn/tests/spark/RssSortSuite.scala | 2 +-
.../celeborn/tests/spark/SkewJoinSuite.scala | 2 +-
.../celeborn/tests/spark/SparkTestBase.scala | 12 +--
.../service/deploy/worker/storage/FileWriter.java | 5 ++
.../service/deploy/worker/CommitInfo.scala | 28 +++++++
.../service/deploy/worker/Controller.scala | 85 +++++++++++++++++++---
.../celeborn/service/deploy/worker/Worker.scala | 3 +
15 files changed, 190 insertions(+), 45 deletions(-)
diff --git
a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index 917f908b..c28621bf 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -85,6 +85,8 @@ class LifecycleManager(appId: String, val conf: CelebornConf)
extends RpcEndpoin
.maximumSize(rpcCacheSize)
.build().asInstanceOf[Cache[Int, ByteBuffer]]
+  // Test only (celeborn.test.retryCommitFiles): mock CommitFiles failures so that the
+  // retry path in requestCommitFilesWithRetry is exercised.
+  private val testCommitFileFailure = conf.testRetryCommitFiles
+
@VisibleForTesting
def workerSnapshots(shuffleId: Int): util.Map[WorkerInfo,
PartitionLocationInfo] =
shuffleAllocatedWorkers.get(shuffleId)
@@ -992,7 +994,7 @@ class LifecycleManager(appId: String, val conf:
CelebornConf) extends RpcEndpoin
masterIds,
slaveIds,
shuffleMapperAttempts.get(shuffleId))
- val res = requestCommitFiles(worker.endpoint, commitFiles)
+ val res = requestCommitFilesWithRetry(worker.endpoint, commitFiles)
res.status match {
case StatusCode.SUCCESS => // do nothing
@@ -1638,21 +1640,35 @@ class LifecycleManager(appId: String, val conf:
CelebornConf) extends RpcEndpoin
}
}
- private def requestCommitFiles(
+  /**
+   * Sends CommitFiles to a worker, making at most `conf.requestCommitFilesMaxRetries`
+   * attempts in total, and returns the first response obtained without an exception.
+   *
+   * If every attempt throws, or the calling thread is interrupted, returns a FAILED
+   * response that carries the requested master/slave ids and no committed ids.
+   *
+   * When `testCommitFileFailure` is set, every attempt except the last one sends the
+   * request asynchronously, waits one second and then mocks a failure. The worker therefore
+   * receives a retried request while the first one may still be in progress.
+   */
+  private def requestCommitFilesWithRetry(
endpoint: RpcEndpointRef,
message: CommitFiles): CommitFilesResponse = {
- try {
- endpoint.askSync[CommitFilesResponse](message)
- } catch {
- case e: Exception =>
- logError(s"AskSync CommitFiles for ${message.shuffleId} failed.", e)
- CommitFilesResponse(
- StatusCode.FAILED,
- List.empty.asJava,
- List.empty.asJava,
- message.masterIds,
- message.slaveIds)
+    val maxRetries = conf.requestCommitFilesMaxRetries
+    var retryTimes = 0
+    while (retryTimes < maxRetries) {
+      try {
+        if (testCommitFileFailure && retryTimes < maxRetries - 1) {
+          // Fire-and-forget: the request is still sent, but this attempt is reported as failed.
+          endpoint.ask[CommitFilesResponse](message)
+          Thread.sleep(1000)
+          throw new Exception("Mock fail for CommitFiles")
+        } else {
+          return endpoint.askSync[CommitFilesResponse](message)
+        }
+      } catch {
+        case ie: InterruptedException =>
+          // Do not keep issuing blocking RPCs from an interrupted thread. Restore the
+          // interrupt flag and fall through to the FAILED response below.
+          Thread.currentThread().interrupt()
+          logError(s"AskSync CommitFiles for ${message.shuffleId} interrupted.", ie)
+          retryTimes = maxRetries
+        case e: Exception =>
+          retryTimes += 1
+          logError(
+            s"AskSync CommitFiles for ${message.shuffleId} failed (attempt $retryTimes/$maxRetries).",
+            e)
+      }
}
+
+    CommitFilesResponse(
+      StatusCode.FAILED,
+      List.empty.asJava,
+      List.empty.asJava,
+      message.masterIds,
+      message.slaveIds)
}
private def requestReleaseSlots(
diff --git
a/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
b/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
index ced17e90..8a507e43 100644
---
a/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
+++
b/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
@@ -53,7 +53,9 @@ public enum StatusCode {
WORKER_SHUTDOWN(25),
NO_AVAILABLE_WORKING_DIR(26),
WORKER_IN_BLACKLIST(27),
- UNKNOWN_WORKER(28);
+ UNKNOWN_WORKER(28),
+
+  // Reported by a worker when its CommitFiles handling throws an exception.
+ COMMIT_FILE_EXCEPTION(29);
private final byte value;
diff --git
a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index 16fcdd82..8881a9cb 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -531,6 +531,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable
with Logging with Se
def workerExcludedExpireTimeout: Long = get(WORKER_EXCLUDED_EXPIRE_TIMEOUT)
def shuffleRangeReadFilterEnabled: Boolean =
get(SHUFFLE_RANGE_READ_FILTER_ENABLED)
def shufflePartitionType: PartitionType =
PartitionType.valueOf(get(SHUFFLE_PARTITION_TYPE))
+ def requestCommitFilesMaxRetries: Int = get(COMMIT_FILE_REQUEST_MAX_RETRY)
// //////////////////////////////////////////////////////
// Shuffle Compression //
@@ -550,6 +551,12 @@ class CelebornConf(loadDefaults: Boolean) extends
Cloneable with Logging with Se
}
}
+ // //////////////////////////////////////////////////////
+ // test //
+ // //////////////////////////////////////////////////////
+ def testFetchFailure: Boolean = get(TEST_FETCH_FAILURE)
+ def testRetryCommitFiles: Boolean = get(TEST_RETRY_COMMIT_FILE)
+
def masterHost: String = get(MASTER_HOST)
def masterPort: Int = get(MASTER_PORT)
@@ -643,7 +650,6 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable
with Logging with Se
def fetchTimeoutMs: Long = get(FETCH_TIMEOUT)
def fetchMaxReqsInFlight: Int = get(FETCH_MAX_REQS_IN_FLIGHT)
def fetchMaxRetries: Int = get(FETCH_MAX_RETRIES)
- def testFetchFailure: Boolean = get(TEST_FETCH_FAILURE)
// //////////////////////////////////////////////////////
// Shuffle Client Push //
@@ -1364,6 +1370,23 @@ object CelebornConf extends Logging {
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefaultString("3s")
+ val COMMIT_FILE_REQUEST_MAX_RETRY: ConfigEntry[Int] =
+ buildConf("celeborn.rpc.requestCommitFiles.maxRetries")
+ .categories("client")
+ .doc("Max number of attempts (the first one included) for the requestCommitFiles RPC.")
+ .version("0.2.0")
+ .intConf
+ .checkValue(v => v > 0, "value must be positive")
+ .createWithDefault(2)
+
+ val TEST_RETRY_COMMIT_FILE: ConfigEntry[Boolean] =
+ buildConf("celeborn.test.retryCommitFiles")
+ .categories("client")
+ .doc("Test only. When true, LifecycleManager mocks a failure for every CommitFiles " +
+ "attempt except the last one, and workers delay their CommitFiles reply, so the " +
+ "retry path is exercised.")
+ .version("0.2.0")
+ .booleanConf
+ .createWithDefault(false)
+
val MASTER_HOST: ConfigEntry[String] =
buildConf("celeborn.master.host")
.categories("master")
@@ -1796,8 +1819,7 @@ object CelebornConf extends Logging {
.categories("worker")
.doc("Timeout for a Celeborn worker to commit files of a shuffle.")
.version("0.2.0")
- .timeConf(TimeUnit.SECONDS)
- .createWithDefaultString("120s")
+ .fallbackConf(RPC_ASK_TIMEOUT)
val PARTITION_SORTER_SORT_TIMEOUT: ConfigEntry[Long] =
buildConf("celeborn.worker.partitionSorter.sort.timeout")
diff --git a/docs/configuration/client.md b/docs/configuration/client.md
index 111452cb..3910f1d7 100644
--- a/docs/configuration/client.md
+++ b/docs/configuration/client.md
@@ -41,6 +41,7 @@ license: |
| celeborn.rpc.cache.expireTime | 15s | The time before a cache item is
removed. | 0.2.0 |
| celeborn.rpc.cache.size | 256 | The max cache items count for rpc cache. |
0.2.0 |
| celeborn.rpc.maxParallelism | 1024 | Max parallelism of client on sending
RPC requests. | 0.2.0 |
+| celeborn.rpc.requestCommitFiles.maxRetries | 2 | Max number of attempts (the first one included) for the requestCommitFiles RPC. | 0.2.0 |
| celeborn.shuffle.batchHandleChangePartition.enabled | false | When true,
LifecycleManager will handle change partition request in batch. Otherwise,
LifecycleManager will process the requests one by one | 0.2.0 |
| celeborn.shuffle.batchHandleChangePartition.interval | 100ms | Interval for
LifecycleManager to schedule handling change partition requests in batch. |
0.2.0 |
| celeborn.shuffle.batchHandleChangePartition.threads | 8 | Threads number for
LifecycleManager to handle change partition request in batch. | 0.2.0 |
@@ -62,6 +63,7 @@ license: |
| celeborn.slots.reserve.retryWait | 3s | Wait time before next retry if
reserve slots failed. | 0.2.0 |
| celeborn.storage.hdfs.dir | <undefined> | HDFS dir configuration for
Celeborn to access HDFS. | 0.2.0 |
| celeborn.test.fetchFailure | false | Wheter to test fetch chunk failure |
0.2.0 |
+| celeborn.test.retryCommitFiles | false | Test only. When true, LifecycleManager mocks a failure for every CommitFiles attempt except the last one, and workers delay their CommitFiles reply, so the retry path is exercised. | 0.2.0 |
| celeborn.worker.excluded.checkInterval | 30s | Interval for client to
refresh excluded worker list. | 0.2.0 |
| celeborn.worker.excluded.expireTimeout | 600s | Timeout time for
LifecycleManager to clear reserved excluded worker. | 0.2.0 |
<!--end-include-->
diff --git a/docs/configuration/worker.md b/docs/configuration/worker.md
index 03881f93..bf1e48ac 100644
--- a/docs/configuration/worker.md
+++ b/docs/configuration/worker.md
@@ -73,7 +73,7 @@ license: |
| celeborn.worker.replicate.port | 0 | Server port for Worker to receive
replicate data request from other Workers. | 0.2.0 |
| celeborn.worker.replicate.threads | 64 | Thread number of worker to
replicate shuffle data. | 0.2.0 |
| celeborn.worker.rpc.port | 0 | Server port for Worker to receive RPC
request. | 0.2.0 |
-| celeborn.worker.shuffle.commit.timeout | 120s | Timeout for a Celeborn
worker to commit files of a shuffle. | 0.2.0 |
+| celeborn.worker.shuffle.commit.timeout | <value of
celeborn.rpc.askTimeout> | Timeout for a Celeborn worker to commit files of
a shuffle. | 0.2.0 |
| celeborn.worker.storage.baseDir.number | 16 | How many directories will be
used if `celeborn.worker.storage.dirs` is not set. The directory name is a
combination of `celeborn.worker.storage.baseDir.prefix` and from one(inclusive)
to `celeborn.worker.storage.baseDir.number`(inclusive) step by one. | 0.2.0 |
| celeborn.worker.storage.baseDir.prefix | /mnt/disk | Base directory for
Celeborn worker to write if `celeborn.worker.storage.dirs` is not set. | 0.2.0
|
| celeborn.worker.storage.dirs | <undefined> | Directory list to store
shuffle data. It's recommended to configure one directory on each disk. Storage
size limit can be set for each directory. For the sake of performance, there
should be no more than 2 flush threads on the same disk partition if you are
using HDD, and should be 8 or more flush threads on the same disk partition if
you are using SSD. For example:
`dir1[:capacity=][:disktype=][:flushthread=],dir2[:capacity=][:disktyp [...]
diff --git
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/HugeDataTest.scala
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/HugeDataTest.scala
index a45258d6..932998e7 100644
---
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/HugeDataTest.scala
+++
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/HugeDataTest.scala
@@ -31,7 +31,7 @@ class HugeDataTest extends AnyFunSuite
override def beforeAll(): Unit = {
logInfo("test initialized , setup rss mini cluster")
- tuple = setupRssMiniCluster()
+ tuple = setupRssMiniClusterSpark()
}
override def afterAll(): Unit = {
diff --git
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/HugeDataTest.scala
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryCommitFilesTest.scala
similarity index 74%
copy from
tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/HugeDataTest.scala
copy to
tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryCommitFilesTest.scala
index a45258d6..afacdf5d 100644
---
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/HugeDataTest.scala
+++
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryCommitFilesTest.scala
@@ -24,14 +24,16 @@ import org.scalatest.funsuite.AnyFunSuite
import org.apache.celeborn.client.ShuffleClient
-class HugeDataTest extends AnyFunSuite
+class RetryCommitFilesTest extends AnyFunSuite
with SparkTestBase
with BeforeAndAfterAll
with BeforeAndAfterEach {
override def beforeAll(): Unit = {
logInfo("test initialized , setup rss mini cluster")
- tuple = setupRssMiniCluster()
+    // Workers also read celeborn.test.retryCommitFiles. They delay their CommitFiles reply,
+    // so the client's retried request arrives while the first commit is still in progress.
+    val workerConf = Map("celeborn.test.retryCommitFiles" -> "true")
+    tuple = setupRssMiniClusterSpark(workerConfs = workerConf)
}
override def afterAll(): Unit = {
@@ -47,11 +49,13 @@ class HugeDataTest extends AnyFunSuite
System.gc()
}
- test("celeborn spark integration test - huge data") {
- val sparkConf = new
SparkConf().setAppName("rss-demo").setMaster("local[4]")
+  test("celeborn spark integration test - retry commit files") {
+    // Client side: LifecycleManager mocks a failure on every CommitFiles attempt but the last.
+    val sparkConf = new SparkConf()
+      .set("spark.celeborn.test.retryCommitFiles", "true")
+      .setAppName("rss-demo").setMaster("local[4]")
val ss = SparkSession.builder().config(updateSparkConf(sparkConf,
false)).getOrCreate()
- ss.sparkContext.parallelize(1 to 10000, 2)
- .map { i => (i, Range(1, 10000).mkString(",")) }.groupByKey(16).collect()
+    val result = ss.sparkContext.parallelize(1 to 1000, 2)
+      .map { i => (i, Range(1, 1000).mkString(",")) }.groupByKey(16).collect()
ss.stop()
+    // Every one of the 1000 distinct keys must survive the retried CommitFiles.
+    assert(result.length == 1000)
}
}
diff --git
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RssHashSuite.scala
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RssHashSuite.scala
index 2146ae87..6e73c217 100644
---
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RssHashSuite.scala
+++
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RssHashSuite.scala
@@ -31,7 +31,7 @@ class RssHashSuite extends AnyFunSuite
override def beforeAll(): Unit = {
logInfo("test initialized , setup rss mini cluster")
- tuple = setupRssMiniCluster()
+ tuple = setupRssMiniClusterSpark()
}
override def afterAll(): Unit = {
diff --git
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RssSortSuite.scala
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RssSortSuite.scala
index 89f53d25..ea9537f0 100644
---
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RssSortSuite.scala
+++
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RssSortSuite.scala
@@ -31,7 +31,7 @@ class RssSortSuite extends AnyFunSuite
override def beforeAll(): Unit = {
logInfo("test initialized , setup rss mini cluster")
- tuple = setupRssMiniCluster()
+ tuple = setupRssMiniClusterSpark()
}
override def afterAll(): Unit = {
diff --git
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SkewJoinSuite.scala
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SkewJoinSuite.scala
index bf6e81e0..19e74005 100644
---
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SkewJoinSuite.scala
+++
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SkewJoinSuite.scala
@@ -35,7 +35,7 @@ class SkewJoinSuite extends AnyFunSuite
override def beforeAll(): Unit = {
logInfo("test initialized , setup rss mini cluster")
- tuple = setupRssMiniCluster()
+ tuple = setupRssMiniClusterSpark()
}
override def afterAll(): Unit = {
diff --git
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
index 6e3aac2d..92d79d5f 100644
---
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
+++
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
@@ -76,7 +76,9 @@ trait SparkTestBase extends Logging with MiniClusterFeature {
tuple._12.interrupt()
}
- def setupRssMiniCluster(): (
+  /**
+   * Starts a mini cluster with one master and three workers for the Spark integration tests.
+   *
+   * @param masterConfs extra configuration passed to createMaster (defaults to null)
+   * @param workerConfs extra configuration passed to every createWorker call (defaults to null)
+   */
+ def setupRssMiniClusterSpark(
+ masterConfs: Map[String, String] = null,
+ workerConfs: Map[String, String] = null): (
Master,
RpcEnv,
Worker,
@@ -91,10 +93,10 @@ trait SparkTestBase extends Logging with MiniClusterFeature
{
Thread) = {
Thread.sleep(3000L)
- val (master, masterRpcEnv) = createMaster()
- val (worker1, workerRpcEnv1) = createWorker()
- val (worker2, workerRpcEnv2) = createWorker()
- val (worker3, workerRpcEnv3) = createWorker()
+ val (master, masterRpcEnv) = createMaster(masterConfs)
+ val (worker1, workerRpcEnv1) = createWorker(workerConfs)
+ val (worker2, workerRpcEnv2) = createWorker(workerConfs)
+ val (worker3, workerRpcEnv3) = createWorker(workerConfs)
val masterThread = runnerWrap(masterRpcEnv.awaitTermination())
val workerThread1 = runnerWrap(worker1.initialize())
val workerThread2 = runnerWrap(worker2.initialize())
diff --git
a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/FileWriter.java
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/FileWriter.java
index 004a41c5..52e3975d 100644
---
a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/FileWriter.java
+++
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/FileWriter.java
@@ -200,6 +200,11 @@ public final class FileWriter implements DeviceObserver {
final int numBytes = data.readableBytes();
MemoryTracker.instance().incrementDiskBuffer(numBytes);
synchronized (this) {
+      // Reject data once this writer is marked closed; checked under the same lock as the write.
+      // NOTE(review): incrementDiskBuffer(numBytes) above has already run when this throws;
+      // confirm the caller releases that disk-buffer accounting on this path.
+ if (closed) {
+ String msg = "FileWriter has already closed!, fileName " +
fileInfo.getFilePath();
+ logger.warn(msg);
+ throw new AlreadyClosedException(msg);
+ }
if (rangeReadFilter) {
mapIdBitMap.add(mapId);
}
diff --git
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/CommitInfo.scala
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/CommitInfo.scala
new file mode 100644
index 00000000..4778d148
--- /dev/null
+++
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/CommitInfo.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.service.deploy.worker
+
+import
org.apache.celeborn.common.protocol.message.ControlMessages.CommitFilesResponse
+
+/**
+ * Commit state of one shuffle on this worker. It is kept so that a retried CommitFiles
+ * request can reuse the outcome of an earlier attempt instead of committing again.
+ *
+ * Both fields are mutable; Controller reads and updates them while synchronized on the
+ * instance.
+ *
+ * @param response the CommitFilesResponse recorded when the commit finishes; null before that
+ * @param status   one of COMMIT_NOTSTARTED, COMMIT_INPROCESS or COMMIT_FINISHED
+ */
+class CommitInfo(var response: CommitFilesResponse, var status: Int)
+
+object CommitInfo {
+  // Possible values of CommitInfo.status.
+ val COMMIT_NOTSTARTED: Int = 0
+ val COMMIT_INPROCESS: Int = 1
+ val COMMIT_FINISHED: Int = 2
+}
diff --git
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
index 18e48433..992779a8 100644
---
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
+++
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
@@ -49,6 +49,8 @@ private[deploy] class Controller(
var workerSource: WorkerSource = _
var storageManager: StorageManager = _
var shuffleMapperAttempts: ConcurrentHashMap[String, Array[Int]] = _
+  // shuffleKey -> CommitInfo; lets a retried CommitFiles request reuse an in-progress or
+  // finished commit instead of committing the shuffle again.
+ var shuffleCommitInfos: ConcurrentHashMap[String, CommitInfo] = _
var workerInfo: WorkerInfo = _
var partitionLocationInfo: PartitionLocationInfo = _
var timer: HashedWheelTimer = _
@@ -57,10 +59,13 @@ private[deploy] class Controller(
val minPartitionSizeToEstimate = conf.minPartitionSizeToEstimate
var shutdown: AtomicBoolean = _
+  // Test only (celeborn.test.retryCommitFiles): delays the CommitFiles reply.
+ val testCommitFileFailure = conf.testRetryCommitFiles
+
def init(worker: Worker): Unit = {
workerSource = worker.workerSource
storageManager = worker.storageManager
shuffleMapperAttempts = worker.shuffleMapperAttempts
+ shuffleCommitInfos = worker.shuffleCommitInfos
workerInfo = worker.workerInfo
partitionLocationInfo = worker.partitionLocationInfo
timer = worker.timer
@@ -312,7 +317,8 @@ private[deploy] class Controller(
slaveIds: jList[String],
mapAttempts: Array[Int]): Unit = {
// return null if shuffleKey does not exist
- if (!partitionLocationInfo.containsShuffle(shuffleKey)) {
+    // A shuffle that already has commit state is still accepted, so that a retried
+    // request can be answered from that state.
+    if (!partitionLocationInfo.containsShuffle(shuffleKey) &&
+      !shuffleCommitInfos.containsKey(shuffleKey)) {
logError(s"Shuffle $shuffleKey doesn't exist!")
context.reply(
CommitFilesResponse(
@@ -324,6 +330,45 @@ private[deploy] class Controller(
return
}
+    // NOTE(review): this value now falls back to celeborn.rpc.askTimeout. waitForCommitFinish
+    // treats it as milliseconds; confirm the timer further below uses the same unit.
+    val shuffleCommitTimeout = conf.workerShuffleCommitTimeout
+
+    // Register commit state for this shuffle. A retried request finds the existing entry.
+    shuffleCommitInfos.putIfAbsent(shuffleKey, new CommitInfo(null, CommitInfo.COMMIT_NOTSTARTED))
+    val commitInfo = shuffleCommitInfos.get(shuffleKey)
+
+    // Polls every 100 ms until the in-progress commit finishes, then replies with its
+    // response. If shuffleCommitTimeout (in ms) elapses first, replies FAILED so the retried
+    // request is not left without any answer.
+    def waitForCommitFinish(): Unit = {
+      val delta = 100
+      var times = 0
+      while (delta * times < shuffleCommitTimeout) {
+        commitInfo.synchronized {
+          if (commitInfo.status == CommitInfo.COMMIT_FINISHED) {
+            context.reply(commitInfo.response)
+            return
+          }
+        }
+        Thread.sleep(delta)
+        times += 1
+      }
+      logError(s"Timed out after $shuffleCommitTimeout ms waiting for in-progress " +
+        s"commitFiles of $shuffleKey.")
+      context.reply(
+        CommitFilesResponse(
+          StatusCode.FAILED,
+          List.empty.asJava,
+          List.empty.asJava,
+          masterIds,
+          slaveIds))
+    }
+
+    commitInfo.synchronized {
+      if (commitInfo.status == CommitInfo.COMMIT_FINISHED) {
+        logInfo(s"$shuffleKey CommitFinished, just return the response")
+        context.reply(commitInfo.response)
+        return
+      } else if (commitInfo.status == CommitInfo.COMMIT_INPROCESS) {
+        logInfo(s"$shuffleKey CommitFiles inprogress, wait for finish")
+        // Wait on the commit pool so this RPC thread is not blocked while polling.
+        commitThreadPool.submit(new Runnable {
+          override def run(): Unit = {
+            waitForCommitFinish()
+          }
+        })
+        return
+      } else {
+        logInfo(s"Start commitFiles for $shuffleKey")
+        commitInfo.status = CommitInfo.COMMIT_INPROCESS
+      }
+    }
+
// close and flush files.
shuffleMapperAttempts.putIfAbsent(shuffleKey, mapAttempts)
@@ -391,10 +436,10 @@ private[deploy] class Controller(
val totalSize = partitionSizeList.asScala.sum
val fileCount = partitionSizeList.size()
// reply
- if (failedMasterIds.isEmpty && failedSlaveIds.isEmpty) {
- logInfo(s"CommitFiles for $shuffleKey success with
${committedMasterIds.size()}" +
- s" master partitions and ${committedSlaveIds.size()} slave
partitions!")
- context.reply(
+ val response =
+ if (failedMasterIds.isEmpty && failedSlaveIds.isEmpty) {
+ logInfo(s"CommitFiles for $shuffleKey success with
${committedMasterIds.size()}" +
+ s" master partitions and ${committedSlaveIds.size()} slave
partitions!")
CommitFilesResponse(
StatusCode.SUCCESS,
committedMasterIdList,
@@ -405,11 +450,10 @@ private[deploy] class Controller(
committedSlaveStorageAndDiskHintList,
committedMapIdBitMapList,
totalSize,
- fileCount))
- } else {
- logWarning(s"CommitFiles for $shuffleKey failed with
${failedMasterIds.size()} master" +
- s" partitions and ${failedSlaveIds.size()} slave partitions!")
- context.reply(
+ fileCount)
+ } else {
+ logWarning(s"CommitFiles for $shuffleKey failed with
${failedMasterIds.size()} master" +
+ s" partitions and ${failedSlaveIds.size()} slave partitions!")
CommitFilesResponse(
StatusCode.PARTIAL_SUCCESS,
committedMasterIdList,
@@ -420,13 +464,20 @@ private[deploy] class Controller(
committedSlaveStorageAndDiskHintList,
committedMapIdBitMapList,
totalSize,
- fileCount))
+ fileCount)
+ }
+      // Test only: delay the reply so that the client's retried request arrives while this
+      // commit is still COMMIT_INPROCESS and goes through waitForCommitFinish.
+ if (testCommitFileFailure) {
+ Thread.sleep(5000)
}
+      // Record the outcome before replying so retried requests see the same response.
+ commitInfo.synchronized {
+ commitInfo.response = response
+ commitInfo.status = CommitInfo.COMMIT_FINISHED
+ }
+ context.reply(response)
}
if (future != null) {
val result = new AtomicReference[CompletableFuture[Unit]]()
- val shuffleCommitTimeout = conf.workerShuffleCommitTimeout
val timeout = timer.newTimeout(
new TimerTask {
@@ -458,6 +509,16 @@ private[deploy] class Controller(
case throwable: Throwable =>
logError("While handling commitFiles, exception occurs.",
throwable)
}
+          // Later retries of this shuffle get this failure response.
+          // NOTE(review): confirm LifecycleManager treats COMMIT_FILE_EXCEPTION as a failed
+          // commit; its status match is only partially visible in this patch.
+ commitInfo.synchronized {
+ commitInfo.response = CommitFilesResponse(
+ StatusCode.COMMIT_FILE_EXCEPTION,
+ List.empty.asJava,
+ List.empty.asJava,
+ masterIds,
+ slaveIds)
+
+ commitInfo.status = CommitInfo.COMMIT_FINISHED
+ }
} else {
// finish, cancel timeout job first.
timeout.cancel()
diff --git
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
index a0067a3a..8633795b 100644
---
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
+++
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
@@ -179,6 +179,8 @@ private[celeborn] class Worker(
val shuffleMapperAttempts = new ConcurrentHashMap[String, Array[Int]]()
val partitionLocationInfo = new PartitionLocationInfo
+  // shuffleKey -> CommitInfo; Controller consults it so that a retried CommitFiles request
+  // reuses an in-progress or finished commit.
+ val shuffleCommitInfos = new ConcurrentHashMap[String, CommitInfo]()
+
private val rssHARetryClient = new RssHARetryClient(rpcEnv, conf)
// (workerInfo -> last connect timeout timestamp)
@@ -420,6 +422,7 @@ private[celeborn] class Worker(
partitionLocationInfo.removeMasterPartitions(shuffleKey)
partitionLocationInfo.removeSlavePartitions(shuffleKey)
shuffleMapperAttempts.remove(shuffleKey)
+  // Drop the cached commit state together with the rest of the expired shuffle's state.
+ shuffleCommitInfos.remove(shuffleKey)
workerInfo.releaseSlots(shuffleKey)
logInfo(s"Cleaned up expired shuffle $shuffleKey")
}