Github user zsxwing commented on a diff in the pull request:

    https://github.com/apache/spark/pull/7276#discussion_r35227189
  
    --- Diff: streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala ---
    @@ -251,171 +312,238 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
         logWarning(s"Error reported by receiver for stream $streamId: 
$messageWithError")
       }
     
    +  private def scheduleReceiver(receiverId: Int): Seq[String] = {
    +    val preferredLocation = receiverPreferredLocations.getOrElse(receiverId, None)
    +    val scheduledLocations = schedulingPolicy.rescheduleReceiver(
    +      receiverId, preferredLocation, receiverTrackingInfos, getExecutors)
    +    updateReceiverScheduledLocations(receiverId, scheduledLocations)
    +    scheduledLocations
    +  }
    +
    +  private def updateReceiverScheduledLocations(
    +      receiverId: Int, scheduledLocations: Seq[String]): Unit = {
    +    val newReceiverTrackingInfo = receiverTrackingInfos.get(receiverId) match {
    +      case Some(oldInfo) =>
    +        oldInfo.copy(state = ReceiverState.SCHEDULED,
    +          scheduledLocations = Some(scheduledLocations))
    +      case None =>
    +        ReceiverTrackingInfo(
    +          receiverId,
    +          ReceiverState.SCHEDULED,
    +          Some(scheduledLocations),
    +          None)
    +    }
    +    receiverTrackingInfos.put(receiverId, newReceiverTrackingInfo)
    +  }
    +
       /** Check if any blocks are left to be processed */
       def hasUnallocatedBlocks: Boolean = {
         receivedBlockTracker.hasUnallocatedReceivedBlocks
       }
     
    +  /**
    +   * Get the list of executors excluding driver
    +   */
    +  private def getExecutors: Seq[String] = {
    +    if (ssc.sc.isLocal) {
    +      Seq(ssc.sparkContext.env.blockManager.blockManagerId.hostPort)
    +    } else {
    +      ssc.sparkContext.env.blockManager.master.getMemoryStatus.filter { case (blockManagerId, _) =>
    +        blockManagerId.executorId != SparkContext.DRIVER_IDENTIFIER // Ignore the driver location
    +      }.map { case (blockManagerId, _) => blockManagerId.hostPort }.toSeq
    +    }
    +  }
    +
    +  /**
    +   * Run the dummy Spark job to ensure that all slaves have registered. This avoids all the
    +   * receivers to be scheduled on the same node.
    +   *
    +   * TODO Should poll the executor number and wait for executors according to
    +   * "spark.scheduler.minRegisteredResourcesRatio" and
    +   * "spark.scheduler.maxRegisteredResourcesWaitingTime" rather than running a dummy job.
    +   */
    +  private def runDummySparkJob(): Unit = {
    +    if (!ssc.sparkContext.isLocal) {
    +      ssc.sparkContext.makeRDD(1 to 50, 50).map(x => (x, 1)).reduceByKey(_ + _, 20).collect()
    +    }
    +    assert(getExecutors.nonEmpty)
    +  }
    +
    +  /**
    +   * Get the receivers from the ReceiverInputDStreams, distributes them to the
    +   * worker nodes as a parallel collection, and runs them.
    +   */
    +  private def launchReceivers(): Unit = {
    +    val receivers = receiverInputStreams.map(nis => {
    +      val rcvr = nis.getReceiver()
    +      rcvr.setReceiverId(nis.id)
    +      rcvr
    +    })
    +
    +    runDummySparkJob()
    +
    +    logInfo("Starting " + receivers.length + " receivers")
    +    endpoint.send(StartAllReceivers(receivers))
    +  }
    +
    +  /** Check if tracker has been marked for starting */
    +  private def isTrackerStarted: Boolean = trackerState == Started
    +
    +  /** Check if tracker has been marked for stopping */
    +  private def isTrackerStopping: Boolean = trackerState == Stopping
    +
    +  /** Check if tracker has been marked for stopped */
    +  private def isTrackerStopped: Boolean = trackerState == Stopped
    +
       /** RpcEndpoint to receive messages from the receivers. */
       private class ReceiverTrackerEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {
     
    +    // TODO Remove this thread pool after https://github.com/apache/spark/issues/7385 is merged
    +    private val submitJobThreadPool = ExecutionContext.fromExecutorService(
    +      ThreadUtils.newDaemonCachedThreadPool("submit-job-thead-pool"))
    +
         override def receive: PartialFunction[Any, Unit] = {
    +      // Local messages
    +      case StartAllReceivers(receivers) =>
    +        val scheduledLocations = schedulingPolicy.scheduleReceivers(receivers, getExecutors)
    +        for (receiver <- receivers) {
    +          val locations = scheduledLocations(receiver.streamId)
    +          updateReceiverScheduledLocations(receiver.streamId, locations)
    +          receiverPreferredLocations(receiver.streamId) = receiver.preferredLocation
    +          startReceiver(receiver, locations)
    +        }
    +      case RestartReceiver(receiver) =>
    +        val scheduledLocations = schedulingPolicy.rescheduleReceiver(
    +          receiver.streamId,
    +          receiver.preferredLocation,
    +          receiverTrackingInfos,
    +          getExecutors)
    +        updateReceiverScheduledLocations(receiver.streamId, scheduledLocations)
    +        startReceiver(receiver, scheduledLocations)
    +      case c @ CleanupOldBlocks(cleanupThreshTime) =>
    +        receiverTrackingInfos.values.flatMap(_.endpoint).foreach(_.send(c))
    +      // Remote messages
           case ReportError(streamId, message, error) =>
             reportError(streamId, message, error)
         }
     
         override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    -      case RegisterReceiver(streamId, typ, host, receiverEndpoint) =>
    +      // Remote messages
    +      case RegisterReceiver(streamId, typ, hostPort, receiverEndpoint) =>
             val successful =
    -          registerReceiver(streamId, typ, host, receiverEndpoint, context.sender.address)
    +          registerReceiver(streamId, typ, hostPort, receiverEndpoint, context.sender.address)
             context.reply(successful)
           case AddBlock(receivedBlockInfo) =>
             context.reply(addBlock(receivedBlockInfo))
           case DeregisterReceiver(streamId, message, error) =>
             deregisterReceiver(streamId, message, error)
             context.reply(true)
    +      // Local messages
    +      case AllReceiverIds =>
    +        context.reply(receiverTrackingInfos.keys.toSeq)
           case StopAllReceivers =>
             assert(isTrackerStopping || isTrackerStopped)
             stopReceivers()
             context.reply(true)
         }
     
    -    /** Send stop signal to the receivers. */
    -    private def stopReceivers() {
    -      // Signal the receivers to stop
    -      receiverInfo.values.flatMap { info => Option(info.endpoint)}
    -        .foreach { _.send(StopReceiver) }
    -      logInfo("Sent stop signal to all " + receiverInfo.size + " 
receivers")
    -    }
    -  }
    -
    -  /** This thread class runs all the receivers on the cluster.  */
    -  class ReceiverLauncher {
    -    @transient val env = ssc.env
    -    @volatile @transient var running = false
    -    @transient val thread = new Thread() {
    -      override def run() {
    -        try {
    -          SparkEnv.set(env)
    -          startReceivers()
    -        } catch {
    -          case ie: InterruptedException => logInfo("ReceiverLauncher interrupted")
    -        }
    -      }
    -    }
    -
    -    def start() {
    -      thread.start()
    -    }
    -
         /**
    -     * Get the list of executors excluding driver
    -     */
    -    private def getExecutors(ssc: StreamingContext): List[String] = {
    -      val executors = ssc.sparkContext.getExecutorMemoryStatus.map(_._1.split(":")(0)).toList
    -      val driver = ssc.sparkContext.getConf.get("spark.driver.host")
    -      executors.diff(List(driver))
    -    }
    -
    -    /** Set host location(s) for each receiver so as to distribute them over
    -     * executors in a round-robin fashion taking into account preferredLocation if set
    +     * Start a receiver along with its scheduled locations
          */
    -    private[streaming] def scheduleReceivers(receivers: Seq[Receiver[_]],
    -      executors: List[String]): Array[ArrayBuffer[String]] = {
    -      val locations = new Array[ArrayBuffer[String]](receivers.length)
    -      var i = 0
    -      for (i <- 0 until receivers.length) {
    -        locations(i) = new ArrayBuffer[String]()
    -        if (receivers(i).preferredLocation.isDefined) {
    -          locations(i) += receivers(i).preferredLocation.get
    -        }
    -      }
    -      var count = 0
    -      for (i <- 0 until max(receivers.length, executors.length)) {
    -        if (!receivers(i % receivers.length).preferredLocation.isDefined) {
    -          locations(i % receivers.length) += executors(count)
    -          count += 1
    -          if (count == executors.length) {
    -            count = 0
    -          }
    -        }
    +    private def startReceiver(receiver: Receiver[_], scheduledLocations: Seq[String]): Unit = {
    +      val receiverId = receiver.streamId
    +      if (!isTrackerStarted) {
    +        onReceiverJobFinish(receiverId)
    +        return
           }
    -      locations
    -    }
    -
    -    /**
    -     * Get the receivers from the ReceiverInputDStreams, distributes them to the
    -     * worker nodes as a parallel collection, and runs them.
    -     */
    -    private def startReceivers() {
    -      val receivers = receiverInputStreams.map(nis => {
    -        val rcvr = nis.getReceiver()
    -        rcvr.setReceiverId(nis.id)
    -        rcvr
    -      })
     
           val checkpointDirOption = Option(ssc.checkpointDir)
           val serializableHadoopConf =
             new SerializableConfiguration(ssc.sparkContext.hadoopConfiguration)
     
           // Function to start the receiver on the worker node
    -      val startReceiver = (iterator: Iterator[Receiver[_]]) => {
    -        if (!iterator.hasNext) {
    -          throw new SparkException(
    -            "Could not start receiver as object not found.")
    -        }
    -        val receiver = iterator.next()
    -        val supervisor = new ReceiverSupervisorImpl(
    -          receiver, SparkEnv.get, serializableHadoopConf.value, checkpointDirOption)
    -        supervisor.start()
    -        supervisor.awaitTermination()
    -      }
    +      val startReceiverFunc = new StartReceiverFunc(checkpointDirOption, serializableHadoopConf)
     
    -      // Run the dummy Spark job to ensure that all slaves have registered.
    -      // This avoids all the receivers to be scheduled on the same node.
    -      if (!ssc.sparkContext.isLocal) {
    -        ssc.sparkContext.makeRDD(1 to 50, 50).map(x => (x, 1)).reduceByKey(_ + _, 20).collect()
    -      }
    -
    -      // Get the list of executors and schedule receivers
    -      val executors = getExecutors(ssc)
    -      val tempRDD =
    -        if (!executors.isEmpty) {
    -          val locations = scheduleReceivers(receivers, executors)
    -          val roundRobinReceivers = (0 until receivers.length).map(i =>
    -            (receivers(i), locations(i)))
    -          ssc.sc.makeRDD[Receiver[_]](roundRobinReceivers)
    +      // Create the RDD using the scheduledLocations to run the receiver in a Spark job
    +      val receiverRDD: RDD[Receiver[_]] =
    +        if (scheduledLocations.isEmpty) {
    +          ssc.sc.makeRDD(Seq(receiver), 1)
             } else {
    -          ssc.sc.makeRDD(receivers, receivers.size)
    +          ssc.sc.makeRDD(Seq(receiver -> scheduledLocations))
             }
    +      receiverRDD.setName(s"Receiver $receiverId")
    +      val future = ssc.sparkContext.submitJob[Receiver[_], Unit, Unit](
    +        receiverRDD, startReceiverFunc, Seq(0), (_, _) => Unit, ())
    +      // We will keep restarting the receiver job until ReceiverTracker is stopped
    +      future.onComplete {
    +        case Success(_) =>
    +          if (!isTrackerStarted) {
    +            onReceiverJobFinish(receiverId)
    +          } else {
    +            logInfo(s"Restarting Receiver $receiverId")
    +            self.send(RestartReceiver(receiver))
    +          }
    +        case Failure(e) =>
    +          if (!isTrackerStarted) {
    +            onReceiverJobFinish(receiverId)
    +          } else {
    +            logError("Receiver has been stopped. Try to restart it.", e)
    +            logInfo(s"Restarting Receiver $receiverId")
    +            self.send(RestartReceiver(receiver))
    +          }
    +      }(submitJobThreadPool)
    +      logInfo(s"Receiver ${receiver.streamId} started")
    +    }
     
    -      // Distribute the receivers and start them
    -      logInfo("Starting " + receivers.length + " receivers")
    -      running = true
    -      try {
    -        ssc.sparkContext.runJob(tempRDD, ssc.sparkContext.clean(startReceiver))
    -        logInfo("All of the receivers have been terminated")
    -      } finally {
    -        running = false
    -      }
    +    override def onStop(): Unit = {
    +      submitJobThreadPool.shutdownNow()
         }
     
         /**
    -     * Wait until the Spark job that runs the receivers is terminated, or return when
    -     * `milliseconds` elapses
    +     * Call when a receiver is terminated. It means we won't restart its Spark job.
          */
    -    def awaitTermination(milliseconds: Long): Unit = {
    -      thread.join(milliseconds)
    +    private def onReceiverJobFinish(receiverId: Int): Unit = {
    +      receiverJobExitLatch.countDown()
    +      receiverTrackingInfos.remove(receiverId).foreach { receiverTrackingInfo =>
    +        if (receiverTrackingInfo.state == ReceiverState.ACTIVE) {
    +          logWarning(s"Receiver $receiverId exited but didn't deregister")
    +        }
    +      }
         }
    -  }
     
    -  /** Check if tracker has been marked for starting */
    -  private def isTrackerStarted(): Boolean = trackerState == Started
    --- End diff --
    
    `isTrackerStarted`, `isTrackerStopping` and `isTrackerStopped` are moved above `ReceiverTrackerEndpoint` so that there is no method of `ReceiverTracker` below `ReceiverTrackerEndpoint`.
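    
    Roughly, the resulting layout looks like the following. This is a simplified, self-contained sketch with placeholder names (`Tracker`, `Endpoint`, a toy `TrackerState`), not the actual `ReceiverTracker.scala`:
    
    ```scala
    // Toy illustration of the ordering described above: the tracker-state
    // predicates are declared before the nested endpoint class, so no method
    // of the outer class appears below the nested class.
    object TrackerState extends Enumeration {
      val Initialized, Started, Stopping, Stopped = Value
    }
    
    class Tracker {
      import TrackerState._
    
      @volatile private var trackerState: TrackerState.Value = Initialized
    
      /** Check if tracker has been marked for starting */
      private def isTrackerStarted: Boolean = trackerState == Started
    
      /** Check if tracker has been marked for stopping */
      private def isTrackerStopping: Boolean = trackerState == Stopping
    
      /** Check if tracker has been marked for stopped */
      private def isTrackerStopped: Boolean = trackerState == Stopped
    
      /** Nested endpoint; everything it needs from the outer class is declared above it. */
      private class Endpoint {
        def receive(msg: Any): Unit = msg match {
          case "StopAllReceivers" =>
            assert(isTrackerStopping || isTrackerStopped)
          case _ => ()
        }
      }
    
      def start(): Unit = { trackerState = Started }
    }
    ```
    
    Keeping all the outer class's members above the nested endpoint just makes it easier to scan which helpers the endpoint relies on.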

