Repository: spark
Updated Branches:
  refs/heads/master 927e52793 -> 71c24aad3


[SPARK-25602][SQL] SparkPlan.getByteArrayRdd should not consume the input when not necessary

## What changes were proposed in this pull request?

In `SparkPlan.getByteArrayRdd`, we should only call `iter.hasNext` when the limit
is not hit, because `iter.hasNext` may produce one row and buffer it, which causes
wrong metrics.

## How was this patch tested?

Added a new test to `SQLMetricsSuite` that checks the `numOutputRows` metrics of Filter and Range.

Closes #22621 from cloud-fan/range.

Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/71c24aad
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/71c24aad
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/71c24aad

Branch: refs/heads/master
Commit: 71c24aad36ae6b3f50447a019bf893490dcf1cf4
Parents: 927e527
Author: Wenchen Fan <[email protected]>
Authored: Thu Oct 4 20:15:21 2018 +0800
Committer: Wenchen Fan <[email protected]>
Committed: Thu Oct 4 20:15:21 2018 +0800

----------------------------------------------------------------------
 .../apache/spark/sql/execution/SparkPlan.scala  |  4 +-
 .../sql/execution/metric/SQLMetricsSuite.scala  | 55 +++++++++++++++++++-
 2 files changed, 57 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/71c24aad/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index ab6031c..9d9b020 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -250,7 +250,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
       val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
       val bos = new ByteArrayOutputStream()
       val out = new DataOutputStream(codec.compressedOutputStream(bos))
-      while (iter.hasNext && (n < 0 || count < n)) {
+      // `iter.hasNext` may produce one row and buffer it, so we should only call it when the
+      // limit is not hit.
+      while ((n < 0 || count < n) && iter.hasNext) {
         val row = iter.next().asInstanceOf[UnsafeRow]
         out.writeInt(row.getSizeInBytes)
         row.writeToStream(out, buffer)

http://git-wip-us.apache.org/repos/asf/spark/blob/71c24aad/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index d45eb0c..085a445 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -24,7 +24,7 @@ import scala.util.Random
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.execution.ui.SQLAppStatusStore
+import org.apache.spark.sql.execution.{FilterExec, RangeExec, SparkPlan, WholeStageCodegenExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
@@ -517,4 +517,57 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared
   test("writing data out metrics with dynamic partition: parquet") {
     testMetricsDynamicPartition("parquet", "parquet", "t1")
   }
+
+  test("SPARK-25602: SparkPlan.getByteArrayRdd should not consume the input when not necessary") {
+    def checkFilterAndRangeMetrics(
+        df: DataFrame,
+        filterNumOutputs: Int,
+        rangeNumOutputs: Int): Unit = {
+      var filter: FilterExec = null
+      var range: RangeExec = null
+      val collectFilterAndRange: SparkPlan => Unit = {
+        case f: FilterExec =>
+          assert(filter == null, "the query should only have one Filter")
+          filter = f
+        case r: RangeExec =>
+          assert(range == null, "the query should only have one Range")
+          range = r
+        case _ =>
+      }
+      // Only search inside whole-stage codegen subtrees: the expected Range counts (one
+      // 1000-row batch per partition) assume the codegen path, so a plan that was not
+      // compiled with whole-stage codegen fails the "doesn't have Filter and Range" check.
+      df.queryExecution.executedPlan.foreach {
+        case w: WholeStageCodegenExec =>
+          w.child.foreach(collectFilterAndRange)
+        case _ =>
+      }
+
+      assert(filter != null && range != null, "the query doesn't have Filter and Range")
+      assert(filter.metrics("numOutputRows").value == filterNumOutputs)
+      assert(range.metrics("numOutputRows").value == rangeNumOutputs)
+    }
+
+    // Create and run the DataFrames inside `withSQLConf`: a query's executed plan is built on
+    // first use and then cached, so flipping the conf around an already-planned DataFrame
+    // (e.g. looping over both settings with the same `df`) would silently test nothing new.
+    withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") {
+      val df = spark.range(0, 3000, 1, 2).toDF().filter('id % 3 === 0)
+      df.collect()
+      checkFilterAndRangeMetrics(df, filterNumOutputs = 1000, rangeNumOutputs = 3000)
+
+      df.queryExecution.executedPlan.foreach(_.resetMetrics())
+      // For each partition, we get 2 rows. Then the Filter should produce 2 rows per-partition,
+      // and Range should produce 1000 rows (one batch) per-partition. Totally Filter produces
+      // 4 rows, and Range produces 2000 rows.
+      df.queryExecution.toRdd.mapPartitions(_.take(2)).collect()
+      checkFilterAndRangeMetrics(df, filterNumOutputs = 4, rangeNumOutputs = 2000)
+
+      // Top-most limit will call `CollectLimitExec.executeCollect`, which will only run the first
+      // task, so totally the Filter produces 2 rows, and Range produces 1000 rows (one batch).
+      val df2 = df.limit(2)
+      df2.collect()
+      checkFilterAndRangeMetrics(df2, filterNumOutputs = 2, rangeNumOutputs = 1000)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to