Github user kiszk commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21898#discussion_r207743465
  
    --- Diff: core/src/main/scala/org/apache/spark/BarrierCoordinator.scala ---
    @@ -0,0 +1,230 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark
    +
    +import java.util.{Timer, TimerTask}
    +import java.util.concurrent.ConcurrentHashMap
    +import java.util.function.Consumer
    +
    +import scala.collection.mutable.ArrayBuffer
    +
    +import org.apache.spark.internal.Logging
    +import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
    +import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerStageCompleted}
    +
    +/**
    + * For each barrier stage attempt, only at most one barrier() call can be active at any time, thus
    + * we can use (stageId, stageAttemptId) to identify the stage attempt where the barrier() call is
    + * from.
    + */
    +private case class ContextBarrierId(stageId: Int, stageAttemptId: Int) {
    +  override def toString: String = s"Stage $stageId (Attempt $stageAttemptId)"
    +}
    +
    +/**
    + * A coordinator that handles all global sync requests from BarrierTaskContext. Each global sync
    + * request is generated by `BarrierTaskContext.barrier()`, and identified by
    + * stageId + stageAttemptId + barrierEpoch. Reply all the blocking global sync requests upon
    + * all the requests for a group of `barrier()` calls are received. If the coordinator is unable to
    + * collect enough global sync requests within a configured time, fail all the requests and return
    + * an Exception with timeout message.
    + */
    +private[spark] class BarrierCoordinator(
    +    timeoutInSecs: Long,
    +    listenerBus: LiveListenerBus,
    +    override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging {
    +
    +  private val timer = new Timer("BarrierCoordinator barrier epoch increment timer")
    +
    +  // Listen to StageCompleted event, clear corresponding ContextBarrierState.
    +  private val listener = new SparkListener {
    +    override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
    +      val stageInfo = stageCompleted.stageInfo
    +      val barrierId = ContextBarrierId(stageInfo.stageId, stageInfo.attemptNumber)
    +      // Clear ContextBarrierState from a finished stage attempt.
    +      cleanupBarrierStage(barrierId)
    +    }
    +  }
    +
    +  // Record all active stage attempts that make barrier() call(s), and the corresponding internal
    +  // state.
    +  private val states = new ConcurrentHashMap[ContextBarrierId, ContextBarrierState]
    +
    +  override def onStart(): Unit = {
    +    super.onStart()
    +    listenerBus.addToStatusQueue(listener)
    +  }
    +
    +  override def onStop(): Unit = {
    +    try {
    +      states.forEachValue(1, clearStateConsumer)
    +      states.clear()
    +      listenerBus.removeListener(listener)
    +    } finally {
    +      super.onStop()
    +    }
    +  }
    +
    +  /**
    +   * Provide the current state of a barrier() call. A state is created when a new stage attempt
    +   * sends out a barrier() call, and recycled on stage completed.
    +   *
    +   * @param barrierId Identifier of the barrier stage that make a barrier() call.
    +   * @param numTasks Number of tasks of the barrier stage, all barrier() calls from the stage shall
    +   *                 collect `numTasks` requests to succeed.
    +   */
    +  private class ContextBarrierState(
    +      val barrierId: ContextBarrierId,
    +      val numTasks: Int) {
    +
    +    // There may be multiple barrier() calls from a barrier stage attempt, `barrierEpoch` is used
    +    // to identify each barrier() call. It shall get increased when a barrier() call succeed, or
    +    // reset when a barrier() call fail due to timeout.
    --- End diff --
    
    nit: `fail` -> `fails`


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to