Github user rdblue commented on a diff in the pull request:
https://github.com/apache/spark/pull/19394#discussion_r143226714
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
---
@@ -73,25 +73,37 @@ case class BroadcastExchangeExec(
         try {
           val beforeCollect = System.nanoTime()
           // Note that we use .executeCollect() because we don't want to convert data to Scala types
-          val input: Array[InternalRow] = child.executeCollect()
-          if (input.length >= 512000000) {
+          val (numRows, input) = child.executeCollectIterator()
+          if (numRows >= 512000000) {
             throw new SparkException(
-              s"Cannot broadcast the table with more than 512 millions rows: ${input.length} rows")
+              s"Cannot broadcast the table with more than 512 millions rows: $numRows rows")
           }
+
           val beforeBuild = System.nanoTime()
           longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000
-          val dataSize = input.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
+
+          // Construct the relation.
+          val relation = mode.transform(input, Some(numRows))
+
+          val dataSize = relation match {
+            case map: HashedRelation =>
+              map.estimatedSize
+            case arr: Array[InternalRow] =>
+              arr.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
+            case _ =>
+              numRows * 512 // guess: each row is about 512 bytes
--- End diff --
This won't cause a regression, because all the broadcast modes return either an
Array or a HashedRelation. The fallback is just in case there is a path that
returns something different in the future. Maybe it would be better to throw an
exception here. What do you think?
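For reference, the fail-fast alternative the comment floats could look like the
sketch below. The trait and case classes are simplified, hypothetical stand-ins
for Spark's actual HashedRelation and Array[InternalRow] results, not real
Spark APIs:

```scala
// Hypothetical stand-in types for the results mode.transform can produce.
sealed trait BroadcastResult
final case class EstimatedMap(estimatedSize: Long) extends BroadcastResult // stands in for HashedRelation
final case class RowArray(rowSizes: Array[Long]) extends BroadcastResult   // stands in for Array[InternalRow]
final case class Unknown() extends BroadcastResult

def dataSize(relation: BroadcastResult): Long = relation match {
  case EstimatedMap(est) => est        // a hashed relation already tracks its own size
  case RowArray(sizes)   => sizes.sum  // sum the per-row byte sizes
  case other =>
    // Fail fast instead of guessing numRows * 512, so a new code path
    // surfaces immediately rather than producing a silent mis-estimate.
    throw new UnsupportedOperationException(
      s"Cannot compute data size for broadcast result: ${other.getClass.getName}")
}
```

The trade-off is that the guess keeps an unforeseen path working with a rough
size, while the exception makes the gap visible at the cost of a hard failure.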
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]