Repository: spark Updated Branches: refs/heads/master 927e52793 -> 71c24aad3
[SPARK-25602][SQL] SparkPlan.getByteArrayRdd should not consume the input when not necessary ## What changes were proposed in this pull request? In `SparkPlan.getByteArrayRdd`, we should only call `iter.hasNext` when the limit is not hit, as `iter.hasNext` may produce one row and buffer it, and cause wrong metrics. ## How was this patch tested? new tests Closes #22621 from cloud-fan/range. Authored-by: Wenchen Fan <[email protected]> Signed-off-by: Wenchen Fan <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/71c24aad Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/71c24aad Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/71c24aad Branch: refs/heads/master Commit: 71c24aad36ae6b3f50447a019bf893490dcf1cf4 Parents: 927e527 Author: Wenchen Fan <[email protected]> Authored: Thu Oct 4 20:15:21 2018 +0800 Committer: Wenchen Fan <[email protected]> Committed: Thu Oct 4 20:15:21 2018 +0800 ---------------------------------------------------------------------- .../apache/spark/sql/execution/SparkPlan.scala | 4 +- .../sql/execution/metric/SQLMetricsSuite.scala | 55 +++++++++++++++++++- 2 files changed, 57 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/71c24aad/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index ab6031c..9d9b020 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -250,7 +250,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ val codec = 
CompressionCodec.createCodec(SparkEnv.get.conf) val bos = new ByteArrayOutputStream() val out = new DataOutputStream(codec.compressedOutputStream(bos)) - while (iter.hasNext && (n < 0 || count < n)) { + // `iter.hasNext` may produce one row and buffer it, we should only call it when the limit is + // not hit. + while ((n < 0 || count < n) && iter.hasNext) { val row = iter.next().asInstanceOf[UnsafeRow] out.writeInt(row.getSizeInBytes) row.writeToStream(out, buffer) http://git-wip-us.apache.org/repos/asf/spark/blob/71c24aad/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index d45eb0c..085a445 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -24,7 +24,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.execution.ui.SQLAppStatusStore +import org.apache.spark.sql.execution.{FilterExec, RangeExec, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -517,4 +517,57 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared test("writing data out metrics with dynamic partition: parquet") { testMetricsDynamicPartition("parquet", "parquet", "t1") } + + test("SPARK-25602: SparkPlan.getByteArrayRdd should not consume the input when not necessary") { + def checkFilterAndRangeMetrics( + df: DataFrame, + filterNumOutputs: Int, + rangeNumOutputs: Int): Unit = { + var filter: 
FilterExec = null + var range: RangeExec = null + val collectFilterAndRange: SparkPlan => Unit = { + case f: FilterExec => + assert(filter == null, "the query should only have one Filter") + filter = f + case r: RangeExec => + assert(range == null, "the query should only have one Range") + range = r + case _ => + } + if (SQLConf.get.wholeStageEnabled) { + df.queryExecution.executedPlan.foreach { + case w: WholeStageCodegenExec => + w.child.foreach(collectFilterAndRange) + case _ => + } + } else { + df.queryExecution.executedPlan.foreach(collectFilterAndRange) + } + + assert(filter != null && range != null, "the query doesn't have Filter and Range") + assert(filter.metrics("numOutputRows").value == filterNumOutputs) + assert(range.metrics("numOutputRows").value == rangeNumOutputs) + } + + val df = spark.range(0, 3000, 1, 2).toDF().filter('id % 3 === 0) + val df2 = df.limit(2) + Seq(true, false).foreach { wholeStageEnabled => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> wholeStageEnabled.toString) { + df.collect() + checkFilterAndRangeMetrics(df, filterNumOutputs = 1000, rangeNumOutputs = 3000) + + df.queryExecution.executedPlan.foreach(_.resetMetrics()) + // For each partition, we get 2 rows. Then the Filter should produce 2 rows per-partition, + // and Range should produce 1000 rows (one batch) per-partition. Totally Filter produces + // 4 rows, and Range produces 2000 rows. + df.queryExecution.toRdd.mapPartitions(_.take(2)).collect() + checkFilterAndRangeMetrics(df, filterNumOutputs = 4, rangeNumOutputs = 2000) + + // Top-most limit will call `CollectLimitExec.executeCollect`, which will only run the first + // task, so totally the Filter produces 2 rows, and Range produces 1000 rows (one batch). 
+ df2.collect() + checkFilterAndRangeMetrics(df2, filterNumOutputs = 2, rangeNumOutputs = 1000) + } + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
