GitHub user cloud-fan commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21898#discussion_r207704051
  
    --- Diff: core/src/main/scala/org/apache/spark/BarrierCoordinator.scala ---
    @@ -0,0 +1,233 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark
    +
    +import java.util.{Timer, TimerTask}
    +import java.util.concurrent.ConcurrentHashMap
    +import java.util.function.Consumer
    +
    +import scala.collection.mutable.ArrayBuffer
    +
    +import org.apache.spark.internal.Logging
    +import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
    +import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, 
SparkListenerStageCompleted}
    +
    +/**
    + * Only one barrier() call shall happen on a barrier stage attempt at each 
time, we can use
    + * (stageId, stageAttemptId) to identify the stage attempt where the 
barrier() call is from.
    + */
    +private case class ContextBarrierId(stageId: Int, stageAttemptId: Int) {
    +  override def toString: String = s"Stage $stageId (Attempt 
$stageAttemptId)"
    +}
    +
    +/**
    + * A coordinator that handles all global sync requests from 
BarrierTaskContext. Each global sync
    + * request is generated by `BarrierTaskContext.barrier()`, and identified 
by
    + * stageId + stageAttemptId + barrierEpoch. Reply all the blocking global 
sync requests upon
    + * received all the requests for a group of `barrier()` calls. If the 
coordinator doesn't collect
    + * enough global sync requests within a configured time, fail all the 
requests due to timeout.
    + */
    +private[spark] class BarrierCoordinator(
    +    timeout: Int,
    +    listenerBus: LiveListenerBus,
    +    override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with 
Logging {
    +
    +  private val timer = new Timer("BarrierCoordinator barrier epoch 
increment timer")
    +
    +  // Listen to StageCompleted event, clear corresponding 
ContextBarrierState.
    +  private val listener = new SparkListener {
    +    override def onStageCompleted(stageCompleted: 
SparkListenerStageCompleted): Unit = {
    +      val stageInfo = stageCompleted.stageInfo
    +      val barrierId = ContextBarrierId(stageInfo.stageId, 
stageInfo.attemptNumber)
    +      // Clear ContextBarrierState from a finished stage attempt.
    +      val barrierState = states.remove(barrierId)
    +      if (barrierState != null) {
    +        barrierState.clear()
    +      }
    +    }
    +  }
    +
    +  // Remember all active stage attempts that make barrier() call(s), and 
the corresponding
    +  // internal state.
    +  private val states = new ConcurrentHashMap[ContextBarrierId, 
ContextBarrierState]
    +
    +  override def onStart(): Unit = {
    +    super.onStart()
    +    listenerBus.addToStatusQueue(listener)
    +  }
    +
    +  /**
    +   * Provide current state of a barrier() call, the state is created when 
a new stage attempt send
    +   * out a barrier() call, and recycled on stage completed.
    +   *
    +   * @param barrierId Identifier of the barrier stage that make a 
barrier() call.
    +   * @param numTasks Number of tasks of the barrier stage, all barrier() 
calls from the stage shall
    +   *                 collect `numTasks` requests to succeed.
    +   */
    +  private class ContextBarrierState(
    +      val barrierId: ContextBarrierId,
    +      val numTasks: Int) {
    +
    +    // There may be multiple barrier() calls from a barrier stage attempt, 
`barrierEpoch` is used
    +    // to identify each barrier() call. It shall get increased when a 
barrier() call succeed, or
    +    // reset when a barrier() call fail due to timeout.
    +    private var barrierEpoch: Int = 0
    +
    +    // An array of RPCCallContexts for barrier tasks that are waiting for 
reply of a barrier()
    +    // call.
    +    private val requesters: ArrayBuffer[RpcCallContext] = new 
ArrayBuffer[RpcCallContext](numTasks)
    +
    +    // A timer task that ensures we may timeout for a barrier() call.
    +    private var timerTask: TimerTask = null
    +
    +    // Init a TimerTask for a barrier() call.
    +    private def initTimerTask(): Unit = {
    +      timerTask = new TimerTask {
    +        override def run(): Unit = {
    +          // Timeout current barrier() call, fail all the sync requests.
    +          failAllRequesters(requesters, "The coordinator didn't get all 
barrier sync " +
    +            s"requests for barrier epoch $barrierEpoch from $barrierId 
within ${timeout}s.")
    +          cleanupBarrierStage(barrierId)
    +        }
    +      }
    +    }
    +
    +    // Cancel the current active TimerTask and release resources.
    +    private def cancelTimerTask(): Unit = {
    +      if (timerTask != null) {
    +        timerTask.cancel()
    +        timerTask = null
    +      }
    +    }
    +
    +    // Process the global sync request. The barrier() call succeed if 
collected enough requests
    +    // within a configured time, otherwise fail all the pending requests.
    +    def handleRequest(requester: RpcCallContext, request: RequestToSync): 
Unit = synchronized {
    +      val taskId = request.taskAttemptId
    +      val epoch = request.barrierEpoch
    +
    +      // Require the number of tasks is correctly set from the 
BarrierTaskContext.
    +      require(request.numTasks == numTasks, s"Number of tasks of 
$barrierId is " +
    +        s"${request.numTasks} from Task $taskId, previously it was 
$numTasks.")
    +
    +      // Check whether the epoch from the barrier tasks matches current 
barrierEpoch.
    +      logInfo(s"Current barrier epoch for $barrierId is $barrierEpoch.")
    +      if (epoch != barrierEpoch) {
    +        requester.sendFailure(new SparkException(s"The request to sync of 
$barrierId with " +
    +          s"barrier epoch $barrierEpoch has already finished. Maybe task 
$taskId is not " +
    +          "properly killed."))
    +      } else {
    +        // If this is the first sync message received for a barrier() 
call, start timer to ensure
    +        // we may timeout for the sync.
    --- End diff --
    
    We create `ContextBarrierState` when we receive the first sync message, so I
think it's clearer to create the timer when creating `ContextBarrierState`;
then we don't need the `if` here.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to