Ngone51 commented on a change in pull request #27395: [SPARK-30667][CORE] Add allGather method to BarrierTaskContext
URL: https://github.com/apache/spark/pull/27395#discussion_r377708416
##########
File path: core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
##########
@@ -153,27 +180,69 @@ private[spark] class BarrierCoordinator(
}
// Add the requester to array of RPCCallContexts pending for reply.
requesters += requester
+ allGatherMessages += allGatherMessage
logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received
update from Task " +
- s"$taskId, current progress: ${requesters.size}/$numTasks.")
- if (maybeFinishAllRequesters(requesters, numTasks)) {
+ s"$taskAttemptId, current progress: ${requesters.size}/$numTasks.")
+ if (maybeFinishAllRequesters(requesters, numTasks, requestMethod)) {
// Finished current barrier() call successfully, clean up
ContextBarrierState and
// increase the barrier epoch.
logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received
all updates from " +
s"tasks, finished successfully.")
barrierEpoch += 1
requesters.clear()
+ allGatherMessages.clear()
cancelTimerTask()
}
}
}
+ /**
+ * Handles a barrier() sync request from a task.
+ *
+ * Unpacks the fields of the BarrierRequestToSync and delegates to
+ * handleRequest. The whole call runs while holding this object's monitor
+ * (`synchronized`).
+ *
+ * @param requester RPC context of the requesting task, passed through
+ * unchanged to handleRequest.
+ * @param request the sync request; its numTasks, stageId, taskAttemptId,
+ * barrierEpoch and requestMethod fields are forwarded positionally.
+ */
+ def handleBarrierRequest(
+ requester: RpcCallContext,
+ request: BarrierRequestToSync
+ ): Unit = synchronized {
+ handleRequest(
+ requester,
+ request.numTasks,
+ request.stageId,
+ request.taskAttemptId,
+ request.barrierEpoch,
+ request.requestMethod
+ )
+ }
+
+ /**
+ * Handles an allGather() sync request from a task.
+ *
+ * Unpacks the fields of the AllGatherRequestToSync and delegates to
+ * handleRequest. The request's allGatherMessage payload is passed as an
+ * extra trailing argument, which handleBarrierRequest does not pass.
+ * The whole call runs while holding this object's monitor (`synchronized`).
+ *
+ * @param requester RPC context of the requesting task, passed through
+ * unchanged to handleRequest.
+ * @param request the all-gather request; its numTasks, stageId,
+ * taskAttemptId, barrierEpoch, requestMethod and allGatherMessage fields
+ * are forwarded positionally.
+ */
+ def handleAllGatherRequest(
+ requester: RpcCallContext,
+ request: AllGatherRequestToSync
+ ): Unit = synchronized {
+ handleRequest(
+ requester,
+ request.numTasks,
+ request.stageId,
+ request.taskAttemptId,
+ request.barrierEpoch,
+ request.requestMethod,
+ request.allGatherMessage
+ )
+ }
+
// Finish all the blocking barrier sync requests from a stage attempt
successfully if we
// have received all the sync requests.
private def maybeFinishAllRequesters(
requesters: ArrayBuffer[RpcCallContext],
- numTasks: Int): Boolean = {
+ numTasks: Int,
+ requestMethod: RequestMethod.Value): Boolean = {
if (requesters.size == numTasks) {
- requesters.foreach(_.reply(()))
+ if (requestMethod == RequestMethod.BARRIER) {
+ requesters.foreach(_.reply(Array[Byte]()))
+ }
+ else if (requestMethod == RequestMethod.ALL_GATHER) {
+ val msgsArray: Array[Array[Byte]] = allGatherMessages.toArray
+ val b = new ByteArrayOutputStream();
Review comment:
Unnecessary semicolon, here and below.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]