Github user HeartSaVioR commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21385#discussion_r190129892
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala ---
    @@ -56,20 +69,73 @@ private[shuffle] class UnsafeRowReceiver(
     
      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
         case r: UnsafeRowReceiverMessage =>
    -      queue.put(r)
    +      queues(r.writerId).put(r)
           context.reply(())
       }
     
       override def read(): Iterator[UnsafeRow] = {
         new NextIterator[UnsafeRow] {
    -      override def getNext(): UnsafeRow = queue.take() match {
    -        case ReceiverRow(r) => r
    -        case ReceiverEpochMarker() =>
    -          finished = true
    -          null
    +      // A map of flags for whether each writer ID has gotten an epoch marker.
    +      private val writerEpochMarkersReceived =
    +        mutable.Map.empty[Int, Boolean].withDefaultValue(false)
    +
    +      private val executor = Executors.newFixedThreadPool(numShuffleWriters)
    +      private val completion = new ExecutorCompletionService[UnsafeRowReceiverMessage](executor)
    +
    +      private def completionTask(writerId: Int) = new Callable[UnsafeRowReceiverMessage] {
    +        override def call(): UnsafeRowReceiverMessage = queues(writerId).take()
           }
     
    -      override def close(): Unit = {}
    +      // Initialize by submitting tasks to read the first row from each writer.
    +      (0 until numShuffleWriters).foreach(writerId => completion.submit(completionTask(writerId)))
    +
    +      /**
    +       * In each call to getNext(), we pull the next row available in the completion
    +       * queue, and then submit another task to read the next row from the writer
    +       * which returned it.
    +       *
    +       * When a writer sends an epoch marker, we note that it's finished and don't
    +       * submit another task for it in this epoch. The iterator is over once all
    +       * writers have sent an epoch marker.
    +       */
    +      override def getNext(): UnsafeRow = {
    +        var nextRow: UnsafeRow = null
    +        while (nextRow == null) {
    +          nextRow = completion.poll(checkpointIntervalMs, TimeUnit.MILLISECONDS) match {
    +            case null =>
    +              // Try again if the poll didn't wait long enough to get a real result.
    +              // But we should be getting at least an epoch marker every checkpoint interval.
    +              logWarning(
    +                s"Completion service failed to make progress after $checkpointIntervalMs ms")
    +              null
    +
    +            // The completion service guarantees this future will be available immediately.
    +            case future => future.get() match {
    +              case ReceiverRow(writerId, r) =>
    +                // Start reading the next element in the queue we just took from.
    +                completion.submit(completionTask(writerId))
    +                r
    +              // TODO use writerId
    --- End diff --
    
    It looks like it isn't needed as of now.
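
    As a side note for readers following the thread: the fan-in pattern this diff implements (one pending blocking take() per writer queue, multiplexed through an ExecutorCompletionService, with a task resubmitted only for the writer that just produced a result) can be sketched standalone. This is a minimal illustrative sketch, not the PR's code; Message, Row, EpochMarker, and the queue setup here are hypothetical stand-ins for UnsafeRowReceiverMessage and its subclasses:

        import java.util.concurrent.{Callable, ExecutorCompletionService, Executors, LinkedBlockingQueue, TimeUnit}

        object CompletionFanInSketch {
          // Hypothetical stand-ins for UnsafeRowReceiverMessage / ReceiverRow / ReceiverEpochMarker.
          sealed trait Message
          case class Row(writerId: Int, value: String) extends Message
          case class EpochMarker(writerId: Int) extends Message

          def main(args: Array[String]): Unit = {
            val numWriters = 2
            val queues = Array.fill(numWriters)(new LinkedBlockingQueue[Message]())
            // Simulate two writers: each enqueues one row, then its epoch marker.
            queues(0).put(Row(0, "a")); queues(0).put(EpochMarker(0))
            queues(1).put(Row(1, "b")); queues(1).put(EpochMarker(1))

            val executor = Executors.newFixedThreadPool(numWriters)
            val completion = new ExecutorCompletionService[Message](executor)
            def takeTask(id: Int) = new Callable[Message] {
              override def call(): Message = queues(id).take()
            }
            // Seed the completion service with one pending take() per writer.
            (0 until numWriters).foreach(id => completion.submit(takeTask(id)))

            var finishedWriters = 0
            while (finishedWriters < numWriters) {
              // Poll with a timeout so a stalled writer can't block the reader forever;
              // a null result here corresponds to the logWarning-and-retry branch above.
              val future = completion.poll(100, TimeUnit.MILLISECONDS)
              if (future != null) future.get() match {
                case Row(id, v) =>
                  println(s"row $v from writer $id")
                  completion.submit(takeTask(id)) // read the next element from the same queue
                case EpochMarker(_) =>
                  finishedWriters += 1 // writer done for this epoch; no resubmit
              }
            }
            executor.shutdown()
          }
        }

    The point of the completion service here is that the reader blocks on whichever writer delivers first, instead of polling each queue in turn.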


---
