Github user viirya commented on a diff in the pull request:
https://github.com/apache/spark/pull/19394#discussion_r143321349
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala ---
@@ -274,19 +274,26 @@ abstract class SparkPlan extends QueryPlan[SparkPlan]
with Logging with Serializ
val byteArrayRdd = getByteArrayRdd()
val results = ArrayBuffer[InternalRow]()
- byteArrayRdd.collect().foreach { bytes =>
- decodeUnsafeRows(bytes).foreach(results.+=)
+ byteArrayRdd.collect().foreach { rdd =>
+ decodeUnsafeRows(rdd._2).foreach(results.+=)
}
results.toArray
}
+ private[spark] def executeCollectIterator(): (Long,
Iterator[InternalRow]) = {
+ val countsAndBytes = getByteArrayRdd().collect()
+ val total = countsAndBytes.map(_._1).sum
+ val rows = countsAndBytes.iterator.flatMap(rdd =>
decodeUnsafeRows(rdd._2))
+ (total, rows)
+ }
+
/**
* Runs this query returning the result as an iterator of InternalRow.
*
* @note Triggers multiple jobs (one for each partition).
*/
def executeToIterator(): Iterator[InternalRow] = {
- getByteArrayRdd().toLocalIterator.flatMap(decodeUnsafeRows)
+ getByteArrayRdd().toLocalIterator.flatMap(rdd =>
decodeUnsafeRows(rdd._2))
--- End diff --
We don't need to collect the counts back to the driver. Besides, the value
handled in the `flatMap` is bytes, not an RDD, so it should not be named `rdd`.
Maybe:
```scala
getByteArrayRdd().map(_._2).toLocalIterator.flatMap(decodeUnsafeRows(_))
```
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]