Github user tdas commented on a diff in the pull request:
https://github.com/apache/spark/pull/21385#discussion_r190402584
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala
---
@@ -56,20 +69,71 @@ private[shuffle] class UnsafeRowReceiver(
override def receiveAndReply(context: RpcCallContext):
PartialFunction[Any, Unit] = {
case r: UnsafeRowReceiverMessage =>
- queue.put(r)
+ queues(r.writerId).put(r)
context.reply(())
}
override def read(): Iterator[UnsafeRow] = {
new NextIterator[UnsafeRow] {
- override def getNext(): UnsafeRow = queue.take() match {
- case ReceiverRow(r) => r
- case ReceiverEpochMarker() =>
- finished = true
- null
+ // An array of flags for whether each writer ID has gotten an epoch
marker.
+ private val writerEpochMarkersReceived =
Array.fill(numShuffleWriters)(false)
+
+ private val executor =
Executors.newFixedThreadPool(numShuffleWriters)
+ private val completion = new
ExecutorCompletionService[UnsafeRowReceiverMessage](executor)
+
+ private def completionTask(writerId: Int) = new
Callable[UnsafeRowReceiverMessage] {
+ override def call(): UnsafeRowReceiverMessage =
queues(writerId).take()
}
- override def close(): Unit = {}
+ // Initialize by submitting tasks to read the first row from each
writer.
+ (0 until numShuffleWriters).foreach(writerId =>
completion.submit(completionTask(writerId)))
+
+ /**
+ * In each call to getNext(), we pull the next row available in the
completion queue, and then
+ * submit another task to read the next row from the writer which
returned it.
+ *
+ * When a writer sends an epoch marker, we note that it's finished
and don't submit another
+ * task for it in this epoch. The iterator is over once all writers
have sent an epoch marker.
+ */
+ override def getNext(): UnsafeRow = {
+ var nextRow: UnsafeRow = null
+ while (nextRow == null) {
+ nextRow = completion.poll(checkpointIntervalMs,
TimeUnit.MILLISECONDS) match {
+ case null =>
+ // Try again if the poll didn't wait long enough to get a
real result.
+ // But we should be getting at least an epoch marker every
checkpoint interval.
+ logWarning(
+ s"Completion service failed to make progress after
$checkpointIntervalMs ms")
+ null
+
+ // The completion service guarantees this future will be
available immediately.
+ case future => future.get() match {
+ case ReceiverRow(writerId, r) =>
+ // Start reading the next element in the queue we just
took from.
+ completion.submit(completionTask(writerId))
+ r
+ case ReceiverEpochMarker(writerId) =>
+ // Don't read any more from this queue. If all the writers
have sent epoch markers,
+ // the epoch is over; otherwise we need to poll from the
remaining writers.
+ writerEpochMarkersReceived(writerId) = true
+ if (writerEpochMarkersReceived.forall(flag => flag)) {
+ finished = true
+ // Break out of the while loop and end the iterator.
+ return null
+ } else {
+ // Poll again for the next completion result.
+ null
+ }
--- End diff --
If you put `nextRow = newReceivedRow` inside the `case ReceiverRow` branch, then
this `else` clause is not needed.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]