GitHub user tdas commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21385#discussion_r190401710
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala ---
    @@ -56,20 +69,71 @@ private[shuffle] class UnsafeRowReceiver(
     
       override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
         case r: UnsafeRowReceiverMessage =>
    -      queue.put(r)
    +      queues(r.writerId).put(r)
           context.reply(())
       }
     
       override def read(): Iterator[UnsafeRow] = {
         new NextIterator[UnsafeRow] {
    -      override def getNext(): UnsafeRow = queue.take() match {
    -        case ReceiverRow(r) => r
    -        case ReceiverEpochMarker() =>
    -          finished = true
    -          null
    +      // An array of flags for whether each writer ID has gotten an epoch marker.
    +      private val writerEpochMarkersReceived = Array.fill(numShuffleWriters)(false)
    +
    +      private val executor = Executors.newFixedThreadPool(numShuffleWriters)
    +      private val completion = new ExecutorCompletionService[UnsafeRowReceiverMessage](executor)
    +
    +      private def completionTask(writerId: Int) = new Callable[UnsafeRowReceiverMessage] {
    +        override def call(): UnsafeRowReceiverMessage = queues(writerId).take()
           }
     
    -      override def close(): Unit = {}
    +      // Initialize by submitting tasks to read the first row from each writer.
    +      (0 until numShuffleWriters).foreach(writerId => completion.submit(completionTask(writerId)))
    +
    +      /**
    +       * In each call to getNext(), we pull the next row available in the completion queue, and then
    +       * submit another task to read the next row from the writer which returned it.
    +       *
    +       * When a writer sends an epoch marker, we note that it's finished and don't submit another
    +       * task for it in this epoch. The iterator is over once all writers have sent an epoch marker.
    +       */
    +      override def getNext(): UnsafeRow = {
    +        var nextRow: UnsafeRow = null
    +        while (nextRow == null) {
    +          nextRow = completion.poll(checkpointIntervalMs, TimeUnit.MILLISECONDS) match {
    +            case null =>
    +              // Try again if the poll didn't wait long enough to get a real result.
    +              // But we should be getting at least an epoch marker every checkpoint interval.
    +              logWarning(
    +                s"Completion service failed to make progress after $checkpointIntervalMs ms")
    +              null
    +
    +            // The completion service guarantees this future will be available immediately.
    +            case future => future.get() match {
    +              case ReceiverRow(writerId, r) =>
    +                // Start reading the next element in the queue we just took from.
    +                completion.submit(completionTask(writerId))
    +                r
    +              case ReceiverEpochMarker(writerId) =>
    +                // Don't read any more from this queue. If all the writers have sent epoch markers,
    +                // the epoch is over; otherwise we need to poll from the remaining writers.
    +                writerEpochMarkersReceived(writerId) = true
    +                if (writerEpochMarkersReceived.forall(flag => flag)) {
    --- End diff --
    
    super nit: `writerEpochMarkersReceived.forall(_ == true)` is easier to understand.
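
    For anyone reading this later, here is a minimal standalone sketch of the completion-service pattern that `getNext()` uses: one blocked `take()` per writer, resubmit a task for whichever writer just produced a row, and finish once every writer has delivered its epoch marker. The `Msg`/`Row`/`EpochMarker` types, the queue contents, and the 100 ms poll below are illustrative stand-ins, not the actual Spark classes.

        import java.util.concurrent.{Callable, ExecutorCompletionService, Executors, LinkedBlockingQueue, TimeUnit}

        // Toy stand-ins for the receiver messages; the real types carry UnsafeRows.
        sealed trait Msg
        case class Row(writerId: Int, value: String) extends Msg
        case class EpochMarker(writerId: Int) extends Msg

        object CompletionServiceSketch {
          def main(args: Array[String]): Unit = {
            val numWriters = 3
            val queues = Array.fill(numWriters)(new LinkedBlockingQueue[Msg]())

            // Pretend each writer sent one row and then its epoch marker.
            (0 until numWriters).foreach { id =>
              queues(id).put(Row(id, s"row-from-$id"))
              queues(id).put(EpochMarker(id))
            }

            val executor = Executors.newFixedThreadPool(numWriters)
            val completion = new ExecutorCompletionService[Msg](executor)
            def task(id: Int) = new Callable[Msg] {
              override def call(): Msg = queues(id).take()
            }

            // Keep exactly one outstanding take() per writer that is still in the epoch.
            (0 until numWriters).foreach(id => completion.submit(task(id)))
            val markerSeen = Array.fill(numWriters)(false)

            while (!markerSeen.forall(_ == true)) {
              completion.poll(100, TimeUnit.MILLISECONDS) match {
                case null =>
                  () // timed out; the real code logs a warning and polls again
                case future => future.get() match {
                  case Row(id, v) =>
                    println(s"got $v")
                    completion.submit(task(id)) // keep reading from this writer's queue
                  case EpochMarker(id) =>
                    markerSeen(id) = true // stop reading from this writer for this epoch
                }
              }
            }
            executor.shutdown()
          }
        }

    The point of multiplexing through the completion service is that the reader never blocks on one idle writer's queue while another writer already has rows ready.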

