Repository: spark
Updated Branches:
  refs/heads/master 729ce3703 -> b486ffc86


[SPARK-19447] Make Range operator generate "recordsRead" metric

## What changes were proposed in this pull request?

The Range operator was modified to report the "recordsRead" input metric instead of
the "generated rows" SQL metric. The tests were updated and partially moved to SQLMetricsSuite.

## How was this patch tested?

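Unit tests: a new "range metrics" test in SQLMetricsSuite, plus updated metrics tests in
JDBCSuite and HiveSerDeSuite.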

Author: Ala Luszczak <[email protected]>

Closes #16960 from ala/range-records-read.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b486ffc8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b486ffc8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b486ffc8

Branch: refs/heads/master
Commit: b486ffc86d8ad6c303321dcf8514afee723f61f8
Parents: 729ce37
Author: Ala Luszczak <[email protected]>
Authored: Sat Feb 18 07:51:41 2017 -0800
Committer: Reynold Xin <[email protected]>
Committed: Sat Feb 18 07:51:41 2017 -0800

----------------------------------------------------------------------
 .../expressions/codegen/CodeGenerator.scala     |   4 +-
 .../sql/execution/basicPhysicalOperators.scala  |  12 +-
 .../InputGeneratedOutputMetricsSuite.scala      | 131 -------------------
 .../sql/execution/metric/SQLMetricsSuite.scala  | 104 +++++++++++++++
 .../org/apache/spark/sql/jdbc/JDBCSuite.scala   |  11 +-
 .../sql/hive/execution/HiveSerDeSuite.scala     |  18 +--
 6 files changed, 125 insertions(+), 155 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b486ffc8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 87932e0..760ead4 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -31,6 +31,7 @@ import org.codehaus.janino.{ByteArrayClassLoader, 
ClassBodyEvaluator, SimpleComp
 import org.codehaus.janino.util.ClassFile
 
 import org.apache.spark.{SparkEnv, TaskContext, TaskKilledException}
+import org.apache.spark.executor.InputMetrics
 import org.apache.spark.internal.Logging
 import org.apache.spark.metrics.source.CodegenMetrics
 import org.apache.spark.sql.catalyst.InternalRow
@@ -933,7 +934,8 @@ object CodeGenerator extends Logging {
       classOf[UnsafeMapData].getName,
       classOf[Expression].getName,
       classOf[TaskContext].getName,
-      classOf[TaskKilledException].getName
+      classOf[TaskKilledException].getName,
+      classOf[InputMetrics].getName
     ))
     evaluator.setExtendedClass(classOf[GeneratedClass])
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b486ffc8/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index c01f9c5..87e90ed 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -365,6 +365,9 @@ case class RangeExec(range: 
org.apache.spark.sql.catalyst.plans.logical.Range)
 
     val taskContext = ctx.freshName("taskContext")
     ctx.addMutableState("TaskContext", taskContext, s"$taskContext = 
TaskContext.get();")
+    val inputMetrics = ctx.freshName("inputMetrics")
+    ctx.addMutableState("InputMetrics", inputMetrics,
+        s"$inputMetrics = $taskContext.taskMetrics().inputMetrics();")
 
     // In order to periodically update the metrics without inflicting 
performance penalty, this
     // operator produces elements in batches. After a batch is complete, the 
metrics are updated
@@ -460,7 +463,7 @@ case class RangeExec(range: 
org.apache.spark.sql.catalyst.plans.logical.Range)
       |     if ($nextBatchTodo == 0) break;
       |   }
       |   $numOutput.add($nextBatchTodo);
-      |   $numGenerated.add($nextBatchTodo);
+      |   $inputMetrics.incRecordsRead($nextBatchTodo);
       |
       |   $batchEnd += $nextBatchTodo * ${step}L;
       | }
@@ -469,7 +472,6 @@ case class RangeExec(range: 
org.apache.spark.sql.catalyst.plans.logical.Range)
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
-    val numGeneratedRows = longMetric("numGeneratedRows")
     sqlContext
       .sparkContext
       .parallelize(0 until numSlices, numSlices)
@@ -488,10 +490,12 @@ case class RangeExec(range: 
org.apache.spark.sql.catalyst.plans.logical.Range)
         val safePartitionEnd = getSafeMargin(partitionEnd)
         val rowSize = UnsafeRow.calculateBitSetWidthInBytes(1) + 
LongType.defaultSize
         val unsafeRow = UnsafeRow.createFromByteArray(rowSize, 1)
+        val taskContext = TaskContext.get()
 
         val iter = new Iterator[InternalRow] {
           private[this] var number: Long = safePartitionStart
           private[this] var overflow: Boolean = false
+          private[this] val inputMetrics = 
taskContext.taskMetrics().inputMetrics
 
           override def hasNext =
             if (!overflow) {
@@ -513,12 +517,12 @@ case class RangeExec(range: 
org.apache.spark.sql.catalyst.plans.logical.Range)
             }
 
             numOutputRows += 1
-            numGeneratedRows += 1
+            inputMetrics.incRecordsRead(1)
             unsafeRow.setLong(0, ret)
             unsafeRow
           }
         }
-        new InterruptibleIterator(TaskContext.get(), iter)
+        new InterruptibleIterator(taskContext, iter)
       }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b486ffc8/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala
deleted file mode 100644
index ddd7a03..0000000
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala
+++ /dev/null
@@ -1,131 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution
-
-import java.io.File
-
-import org.scalatest.concurrent.Eventually
-
-import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
-import org.apache.spark.sql.{DataFrame, QueryTest}
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.util.Utils
-
-class InputGeneratedOutputMetricsSuite extends QueryTest with SharedSQLContext 
with Eventually {
-
-  test("Range query input/output/generated metrics") {
-    val numRows = 150L
-    val numSelectedRows = 100L
-    val res = MetricsTestHelper.runAndGetMetrics(spark.range(0, numRows, 1).
-      filter(x => x < numSelectedRows).toDF())
-
-    assert(res.recordsRead.sum === 0)
-    assert(res.shuffleRecordsRead.sum === 0)
-    assert(res.generatedRows === numRows :: Nil)
-    assert(res.outputRows === numSelectedRows :: numRows :: Nil)
-  }
-
-  test("Input/output/generated metrics with repartitioning") {
-    val numRows = 100L
-    val res = MetricsTestHelper.runAndGetMetrics(
-      spark.range(0, numRows).repartition(3).filter(x => x % 5 == 0).toDF())
-
-    assert(res.recordsRead.sum === 0)
-    assert(res.shuffleRecordsRead.sum === numRows)
-    assert(res.generatedRows === numRows :: Nil)
-    assert(res.outputRows === 20 :: numRows :: Nil)
-  }
-
-  test("Input/output/generated metrics with more repartitioning") {
-    withTempDir { tempDir =>
-      val dir = new File(tempDir, "pqS").getCanonicalPath
-
-      spark.range(10).write.parquet(dir)
-      spark.read.parquet(dir).createOrReplaceTempView("pqS")
-
-      val res = MetricsTestHelper.runAndGetMetrics(
-        spark.range(0, 30).repartition(3).crossJoin(sql("select * from 
pqS")).repartition(2)
-            .toDF()
-      )
-
-      assert(res.recordsRead.sum == 10)
-      assert(res.shuffleRecordsRead.sum == 3 * 10 + 2 * 150)
-      assert(res.generatedRows == 30 :: Nil)
-      assert(res.outputRows == 10 :: 30 :: 300 :: Nil)
-    }
-  }
-}
-
-object MetricsTestHelper {
-  case class AggregatedMetricsResult(
-      recordsRead: List[Long],
-      shuffleRecordsRead: List[Long],
-      generatedRows: List[Long],
-      outputRows: List[Long])
-
-  private[this] def extractMetricValues(
-      df: DataFrame,
-      metricValues: Map[Long, String],
-      metricName: String): List[Long] = {
-    df.queryExecution.executedPlan.collect {
-      case plan if plan.metrics.contains(metricName) =>
-        metricValues(plan.metrics(metricName).id).toLong
-    }.toList.sorted
-  }
-
-  def runAndGetMetrics(df: DataFrame, useWholeStageCodeGen: Boolean = false):
-      AggregatedMetricsResult = {
-    val spark = df.sparkSession
-    val sparkContext = spark.sparkContext
-
-    var recordsRead = List[Long]()
-    var shuffleRecordsRead = List[Long]()
-    val listener = new SparkListener() {
-      override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
-        if (taskEnd.taskMetrics != null) {
-          recordsRead = taskEnd.taskMetrics.inputMetrics.recordsRead ::
-            recordsRead
-          shuffleRecordsRead = 
taskEnd.taskMetrics.shuffleReadMetrics.recordsRead ::
-            shuffleRecordsRead
-        }
-      }
-    }
-
-    val oldExecutionIds = spark.sharedState.listener.executionIdToData.keySet
-
-    val prevUseWholeStageCodeGen =
-      spark.sessionState.conf.getConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED)
-    try {
-      spark.sessionState.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, 
useWholeStageCodeGen)
-      sparkContext.listenerBus.waitUntilEmpty(10000)
-      sparkContext.addSparkListener(listener)
-      df.collect()
-      sparkContext.listenerBus.waitUntilEmpty(10000)
-    } finally {
-      spark.sessionState.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, 
prevUseWholeStageCodeGen)
-    }
-
-    val executionId = 
spark.sharedState.listener.executionIdToData.keySet.diff(oldExecutionIds).head
-    val metricValues = 
spark.sharedState.listener.getExecutionMetrics(executionId)
-    val outputRes = extractMetricValues(df, metricValues, "numOutputRows")
-    val generatedRes = extractMetricValues(df, metricValues, 
"numGeneratedRows")
-
-    AggregatedMetricsResult(recordsRead.sorted, shuffleRecordsRead.sorted, 
generatedRes, outputRes)
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/b486ffc8/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 229d881..2ce7db6 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -17,7 +17,12 @@
 
 package org.apache.spark.sql.execution.metric
 
+import java.io.File
+
+import scala.collection.mutable.HashMap
+
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.execution.SparkPlanInfo
@@ -309,4 +314,103 @@ class SQLMetricsSuite extends SparkFunSuite with 
SharedSQLContext {
     assert(metricInfoDeser.metadata === 
Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER))
   }
 
+  test("range metrics") {
+    val res1 = InputOutputMetricsHelper.run(
+      spark.range(30).filter(x => x % 3 == 0).toDF()
+    )
+    assert(res1 === (30L, 0L, 30L) :: Nil)
+
+    val res2 = InputOutputMetricsHelper.run(
+      spark.range(150).repartition(4).filter(x => x < 10).toDF()
+    )
+    assert(res2 === (150L, 0L, 150L) :: (0L, 150L, 10L) :: Nil)
+
+    withTempDir { tempDir =>
+      val dir = new File(tempDir, "pqS").getCanonicalPath
+
+      spark.range(10).write.parquet(dir)
+      spark.read.parquet(dir).createOrReplaceTempView("pqS")
+
+      val res3 = InputOutputMetricsHelper.run(
+        spark.range(30).repartition(3).crossJoin(sql("select * from 
pqS")).repartition(2).toDF()
+      )
+      // The query above is executed in the following stages:
+      //   1. sql("select * from pqS")    => (10, 0, 10)
+      //   2. range(30)                   => (30, 0, 30)
+      //   3. crossJoin(...) of 1. and 2. => (0, 30, 300)
+      //   4. shuffle & return results    => (0, 300, 0)
+      assert(res3 === (10L, 0L, 10L) :: (30L, 0L, 30L) :: (0L, 30L, 300L) :: 
(0L, 300L, 0L) :: Nil)
+    }
+  }
+}
+
+object InputOutputMetricsHelper {
+  private class InputOutputMetricsListener extends SparkListener {
+    private case class MetricsResult(
+        var recordsRead: Long = 0L,
+        var shuffleRecordsRead: Long = 0L,
+        var sumMaxOutputRows: Long = 0L)
+
+    private[this] val stageIdToMetricsResult = HashMap.empty[Int, 
MetricsResult]
+
+    def reset(): Unit = {
+      stageIdToMetricsResult.clear()
+    }
+
+    /**
+     * Return a list of recorded metrics aggregated per stage.
+     *
+     * The list is sorted in the ascending order on the stageId.
+     * For each recorded stage, the following tuple is returned:
+     *  - sum of inputMetrics.recordsRead for all the tasks in the stage
+     *  - sum of shuffleReadMetrics.recordsRead for all the tasks in the stage
+     *  - sum of the highest values of "number of output rows" metric for all 
the tasks in the stage
+     */
+    def getResults(): List[(Long, Long, Long)] = {
+      stageIdToMetricsResult.keySet.toList.sorted.map { stageId =>
+        val res = stageIdToMetricsResult(stageId)
+        (res.recordsRead, res.shuffleRecordsRead, res.sumMaxOutputRows)
+      }
+    }
+
+    override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized 
{
+      val res = stageIdToMetricsResult.getOrElseUpdate(taskEnd.stageId, 
MetricsResult())
+
+      res.recordsRead += taskEnd.taskMetrics.inputMetrics.recordsRead
+      res.shuffleRecordsRead += 
taskEnd.taskMetrics.shuffleReadMetrics.recordsRead
+
+      var maxOutputRows = 0L
+      for (accum <- taskEnd.taskMetrics.externalAccums) {
+        val info = accum.toInfo(Some(accum.value), None)
+        if (info.name.toString.contains("number of output rows")) {
+          info.update match {
+            case Some(n: Number) =>
+              if (n.longValue() > maxOutputRows) {
+                maxOutputRows = n.longValue()
+              }
+            case _ => // Ignore.
+          }
+        }
+      }
+      res.sumMaxOutputRows += maxOutputRows
+    }
+  }
+
+  // Run df.collect() and return aggregated metrics for each stage.
+  def run(df: DataFrame): List[(Long, Long, Long)] = {
+    val spark = df.sparkSession
+    val sparkContext = spark.sparkContext
+    val listener = new InputOutputMetricsListener()
+    sparkContext.addSparkListener(listener)
+
+    try {
+      sparkContext.listenerBus.waitUntilEmpty(5000)
+      listener.reset()
+      df.collect()
+      sparkContext.listenerBus.waitUntilEmpty(5000)
+    } finally {
+      sparkContext.removeSparkListener(listener)
+    }
+    listener.getResults()
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/b486ffc8/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 92d3e95..5463728 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.DataSourceScanExec
 import org.apache.spark.sql.execution.command.ExplainCommand
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, 
JDBCRelation, JdbcUtils}
-import org.apache.spark.sql.execution.MetricsTestHelper
+import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
@@ -917,13 +917,10 @@ class JDBCSuite extends SparkFunSuite
     assert(e2.contains("User specified schema not supported with `jdbc`"))
   }
 
-  test("Input/generated/output metrics on JDBC") {
+  test("Checking metrics correctness with JDBC") {
     val foobarCnt = spark.table("foobar").count()
-    val res = MetricsTestHelper.runAndGetMetrics(sql("SELECT * FROM 
foobar").toDF())
-    assert(res.recordsRead === foobarCnt :: Nil)
-    assert(res.shuffleRecordsRead.sum === 0)
-    assert(res.generatedRows.isEmpty)
-    assert(res.outputRows === foobarCnt :: Nil)
+    val res = InputOutputMetricsHelper.run(sql("SELECT * FROM foobar").toDF())
+    assert(res === (foobarCnt, 0L, foobarCnt) :: Nil)
   }
 
   test("SPARK-19318: Connection properties keys should be case-sensitive.") {

http://git-wip-us.apache.org/repos/asf/spark/blob/b486ffc8/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
index 35c41b5..7803ac3 100644
--- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
+++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.hive.execution
 
 import org.scalatest.BeforeAndAfterAll
 
-import org.apache.spark.sql.execution.MetricsTestHelper
+import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper
 import org.apache.spark.sql.hive.test.TestHive
 
 /**
@@ -49,21 +49,15 @@ class HiveSerDeSuite extends HiveComparisonTest with 
BeforeAndAfterAll {
 
   createQueryTest("Read Partitioned with AvroSerDe", "SELECT * FROM 
episodes_part")
 
-  test("Test input/generated/output metrics") {
+  test("Checking metrics correctness") {
     import TestHive._
 
     val episodesCnt = sql("select * from episodes").count()
-    val episodesRes = MetricsTestHelper.runAndGetMetrics(sql("select * from 
episodes").toDF())
-    assert(episodesRes.recordsRead === episodesCnt :: Nil)
-    assert(episodesRes.shuffleRecordsRead.sum === 0)
-    assert(episodesRes.generatedRows.isEmpty)
-    assert(episodesRes.outputRows === episodesCnt :: Nil)
+    val episodesRes = InputOutputMetricsHelper.run(sql("select * from 
episodes").toDF())
+    assert(episodesRes === (episodesCnt, 0L, episodesCnt) :: Nil)
 
     val serdeinsCnt = sql("select * from serdeins").count()
-    val serdeinsRes = MetricsTestHelper.runAndGetMetrics(sql("select * from 
serdeins").toDF())
-    assert(serdeinsRes.recordsRead === serdeinsCnt :: Nil)
-    assert(serdeinsRes.shuffleRecordsRead.sum === 0)
-    assert(serdeinsRes.generatedRows.isEmpty)
-    assert(serdeinsRes.outputRows === serdeinsCnt :: Nil)
+    val serdeinsRes = InputOutputMetricsHelper.run(sql("select * from 
serdeins").toDF())
+    assert(serdeinsRes === (serdeinsCnt, 0L, serdeinsCnt) :: Nil)
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to