Github user JoshRosen commented on a diff in the pull request:
https://github.com/apache/spark/pull/14854#discussion_r77118807
--- Diff: core/src/main/scala/org/apache/spark/rdd/RDD.scala ---
@@ -1331,6 +1335,103 @@ abstract class RDD[T: ClassTag](
}
}
+  private[spark] def takeOnline[R: ClassTag](
+      num: Int,
+      unpackPartition: Array[T] => Iterator[R]): Array[R] = withScope {
+    require(num >= 0, s"num cannot be negative, but got num=$num")
+    val lock = new Object()
+    val totalPartitions = partitions.length
+    var partitionsScanned = 0
+    var gotEnoughRows = false
+    // This buffer accumulates the rows to be returned.
+    val resultToReturn = new ArrayBuffer[R]
+    // In order to preserve the behavior of the old `take()` implementation, it's important that
+    // we process partitions in order of their partition ids. Partitions may be computed out of
+    // order. Once we have received all partitions up to partition N then we can perform
+    // driver-side processing on partitions 1 through N to determine whether we've received
+    // enough items.
+    val completedPartitions = new mutable.HashMap[Int, Array[T]]()  // key is partition id
+    var firstMissingPartition: Int = 0
+
+    var jobFuture: SimpleFutureAction[Unit] = null
+
+    // This callback is invoked as individual partitions complete.
+    def handleResult(taskIndex: Int, result: Array[T]): Unit = lock.synchronized {
+      val partitionId = partitionsScanned + taskIndex
+      assert(partitionId < totalPartitions)
+      if (gotEnoughRows) {
+        logDebug(s"Ignoring result for partition $partitionId of $this because we have enough rows")
+      } else {
+        logDebug(s"Handling result for partition $partitionId of $this")
+        // Buffer the result in case we can't process it now.
+        completedPartitions(partitionId) = result
+        if (partitionId == firstMissingPartition) {
+          while (!gotEnoughRows && completedPartitions.contains(firstMissingPartition)) {
+            logDebug(s"Unpacking partition $firstMissingPartition of $this")
+            val rawPartitionData = completedPartitions.remove(firstMissingPartition).get
+            resultToReturn ++= unpackPartition(rawPartitionData)
+            firstMissingPartition += 1
+
+            if (resultToReturn.size >= num) {
+              // We have unpacked enough results to reach the desired number of results,
+              // so discard any remaining partitions' data:
+              completedPartitions.clear()
+              // Set a flag so that future task completion events are ignored:
+              gotEnoughRows = true
+              // Cancel the job so we can return sooner
+              jobFuture.cancelWithoutFailing()
+            }
+          }
+        }
+      }
+    }
+
+    while (!gotEnoughRows && partitionsScanned < totalPartitions) {
+      var numPartitionsToCompute = 0
+      lock.synchronized {
+        // The number of partitions to try in this iteration. It is ok for this number to be
+        // greater than totalParts because we actually cap it at totalParts in runJob.
+        var numPartsToTry = 1L
+        if (partitionsScanned > 0) {
+          // If we didn't find any rows after the first iteration, just try all partitions next.
+          // Otherwise, interpolate the number of partitions we need to try, but overestimate it
+          // by 50%.
+          if (resultToReturn.isEmpty) {
+            numPartsToTry = totalPartitions - 1
+          } else {
+            numPartsToTry = (1.5 * num * partitionsScanned / resultToReturn.size).toInt
+          }
+        }
+        numPartsToTry = math.max(0, numPartsToTry)  // guard against negative num of partitions
+
+        val partitionsToCompute = partitionsScanned.until(
+          math.min(partitionsScanned + numPartsToTry, totalPartitions).toInt)
+        numPartitionsToCompute = partitionsToCompute.length
+
+        jobFuture = sc.submitJob(
+          this,
+          (it: Iterator[T]) => it.toArray,
+          partitionsToCompute,
+          handleResult,
+          resultFunc = ())
+      }
+
+      // scalastyle:off awaitresult
+      Await.result(jobFuture, Duration.Inf)
+      // scalastyle:on awaitresult
+      sparkContext.progressBar.foreach(_.finishAll())
--- End diff ---
In a nutshell, the issue here is that `submitJob` (and `AsyncRDDActions` more generally)
doesn't provide completion callbacks that remove the progress bar; without this call here,
the console progress bar won't be hidden / removed once the job completes. I'm not sure it's
a good idea to push this logic into `submitJob` / `JobWaiter` itself, because that might
introduce races between output printed after `take()` returns and the backspace sequences
sent to remove the progress bar.
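
To make that race concrete, here's a rough, hypothetical sketch (not part of this patch) of
what pushing the cleanup into a completion callback might look like, using the standard
`scala.concurrent.Future#onComplete` hook that `SimpleFutureAction` inherits:

```scala
import scala.concurrent.ExecutionContext.Implicits.global

// Hypothetical alternative: let the future hide the progress bar when the job finishes.
jobFuture.onComplete { _ =>
  // This runs asynchronously on an ExecutionContext thread, so the backspace sequences that
  // erase the bar could be written after take() has already returned and the caller has
  // started printing its own output -- exactly the interleaving described above.
  sparkContext.progressBar.foreach(_.finishAll())
}
```

Calling `finishAll()` synchronously after `Await.result`, as the patch does, avoids that
interleaving because the bar is already gone before `take()` returns.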