Repository: spark
Updated Branches:
  refs/heads/master be317d4a9 -> bf5496dbd
http://git-wip-us.apache.org/repos/asf/spark/blob/bf5496db/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index 5755c00..7bf9225 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -19,200 +19,106 @@ package org.apache.spark.sql.execution.metric
 import java.text.NumberFormat
-import org.apache.spark.{Accumulable, AccumulableParam, Accumulators, SparkContext}
+import org.apache.spark.{NewAccumulator, SparkContext}
 import org.apache.spark.scheduler.AccumulableInfo
 import org.apache.spark.util.Utils
-/**
- * Create a layer for specialized metric. We cannot add `@specialized` to
- * `Accumulable/AccumulableParam` because it will break Java source compatibility.
- *
- * An implementation of SQLMetric should override `+=` and `add` to avoid boxing.
- */
-private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T](
-    name: String,
-    val param: SQLMetricParam[R, T]) extends Accumulable[R, T](param.zero, param, Some(name)) {
-  // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later
-  override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = {
-    new AccumulableInfo(id, Some(name), update, value, true, countFailedValues,
-      Some(SQLMetrics.ACCUM_IDENTIFIER))
-  }
-
-  def reset(): Unit = {
-    this.value = param.zero
-  }
-}
-
-/**
- * Create a layer for specialized metric. We cannot add `@specialized` to
- * `Accumulable/AccumulableParam` because it will break Java source compatibility.
- */
-private[sql] trait SQLMetricParam[R <: SQLMetricValue[T], T] extends AccumulableParam[R, T] {
-
-  /**
-   * A function that defines how we aggregate the final accumulator results among all tasks,
-   * and represent it in string for a SQL physical operator.
-   */
-  val stringValue: Seq[T] => String
-
-  def zero: R
-}
+class SQLMetric(val metricType: String, initValue: Long = 0L) extends NewAccumulator[Long, Long] {
+  // This is a workaround for SPARK-11013.
+  // We may use -1 as initial value of the accumulator, if the accumulator is valid, we will
+  // update it at the end of task and the value will be at least 0. Then we can filter out the -1
+  // values before calculate max, min, etc.
+  private[this] var _value = initValue
-/**
- * Create a layer for specialized metric. We cannot add `@specialized` to
- * `Accumulable/AccumulableParam` because it will break Java source compatibility.
- */
-private[sql] trait SQLMetricValue[T] extends Serializable {
+  override def copyAndReset(): SQLMetric = new SQLMetric(metricType, initValue)
-  def value: T
-
-  override def toString: String = value.toString
-}
-
-/**
- * A wrapper of Long to avoid boxing and unboxing when using Accumulator
- */
-private[sql] class LongSQLMetricValue(private var _value : Long) extends SQLMetricValue[Long] {
-
-  def add(incr: Long): LongSQLMetricValue = {
-    _value += incr
-    this
+  override def merge(other: NewAccumulator[Long, Long]): Unit = other match {
+    case o: SQLMetric => _value += o.localValue
+    case _ => throw new UnsupportedOperationException(
+      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
   }
-  // Although there is a boxing here, it's fine because it's only called in SQLListener
-  override def value: Long = _value
-
-  // Needed for SQLListenerSuite
-  override def equals(other: Any): Boolean = other match {
-    case o: LongSQLMetricValue => value == o.value
-    case _ => false
-  }
+  override def isZero(): Boolean = _value == initValue
-  override def hashCode(): Int = _value.hashCode()
-}
+  override def add(v: Long): Unit = _value += v
-/**
- * A specialized long Accumulable to avoid boxing and unboxing when using Accumulator's
- * `+=` and `add`.
- */
-private[sql] class LongSQLMetric private[metric](name: String, param: LongSQLMetricParam)
-  extends SQLMetric[LongSQLMetricValue, Long](name, param) {
+  def +=(v: Long): Unit = _value += v
-  override def +=(term: Long): Unit = {
-    localValue.add(term)
-  }
+  override def localValue: Long = _value
-  override def add(term: Long): Unit = {
-    localValue.add(term)
+  // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later
+  private[spark] override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = {
+    new AccumulableInfo(id, name, update, value, true, true, Some(SQLMetrics.ACCUM_IDENTIFIER))
   }
-}
-
-private class LongSQLMetricParam(val stringValue: Seq[Long] => String, initialValue: Long)
-  extends SQLMetricParam[LongSQLMetricValue, Long] {
-
-  override def addAccumulator(r: LongSQLMetricValue, t: Long): LongSQLMetricValue = r.add(t)
-  override def addInPlace(r1: LongSQLMetricValue, r2: LongSQLMetricValue): LongSQLMetricValue =
-    r1.add(r2.value)
-
-  override def zero(initialValue: LongSQLMetricValue): LongSQLMetricValue = zero
-
-  override def zero: LongSQLMetricValue = new LongSQLMetricValue(initialValue)
+  def reset(): Unit = _value = initValue
 }
-private object LongSQLMetricParam
-  extends LongSQLMetricParam(x => NumberFormat.getInstance().format(x.sum), 0L)
-
-private object StatisticsBytesSQLMetricParam extends LongSQLMetricParam(
-  (values: Seq[Long]) => {
-    // This is a workaround for SPARK-11013.
-    // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update
-    // it at the end of task and the value will be at least 0.
-    val validValues = values.filter(_ >= 0)
-    val Seq(sum, min, med, max) = {
-      val metric = if (validValues.length == 0) {
-        Seq.fill(4)(0L)
-      } else {
-        val sorted = validValues.sorted
-        Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
-      }
-      metric.map(Utils.bytesToString)
-    }
-    s"\n$sum ($min, $med, $max)"
-  }, -1L)
-
-private object StatisticsTimingSQLMetricParam extends LongSQLMetricParam(
-  (values: Seq[Long]) => {
-    // This is a workaround for SPARK-11013.
-    // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update
-    // it at the end of task and the value will be at least 0.
-    val validValues = values.filter(_ >= 0)
-    val Seq(sum, min, med, max) = {
-      val metric = if (validValues.length == 0) {
-        Seq.fill(4)(0L)
-      } else {
-        val sorted = validValues.sorted
-        Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
-      }
-      metric.map(Utils.msDurationToString)
-    }
-    s"\n$sum ($min, $med, $max)"
-  }, -1L)
 private[sql] object SQLMetrics {
-  // Identifier for distinguishing SQL metrics from other accumulators
   private[sql] val ACCUM_IDENTIFIER = "sql"
-  private def createLongMetric(
-      sc: SparkContext,
-      name: String,
-      param: LongSQLMetricParam): LongSQLMetric = {
-    val acc = new LongSQLMetric(name, param)
-    // This is an internal accumulator so we need to register it explicitly.
-    Accumulators.register(acc)
-    sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc))
-    acc
-  }
+  private[sql] val SUM_METRIC = "sum"
+  private[sql] val SIZE_METRIC = "size"
+  private[sql] val TIMING_METRIC = "timing"
-  def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = {
-    createLongMetric(sc, name, LongSQLMetricParam)
+  def createMetric(sc: SparkContext, name: String): SQLMetric = {
+    val acc = new SQLMetric(SUM_METRIC)
+    acc.register(sc, name = Some(name), countFailedValues = true)
+    acc
   }
   /**
    * Create a metric to report the size information (including total, min, med, max) like data size,
    * spill size, etc.
    */
-  def createSizeMetric(sc: SparkContext, name: String): LongSQLMetric = {
+  def createSizeMetric(sc: SparkContext, name: String): SQLMetric = {
     // The final result of this metric in physical operator UI may looks like:
     // data size total (min, med, max):
     // 100GB (100MB, 1GB, 10GB)
-    createLongMetric(sc, s"$name total (min, med, max)", StatisticsBytesSQLMetricParam)
+    val acc = new SQLMetric(SIZE_METRIC, -1)
+    acc.register(sc, name = Some(s"$name total (min, med, max)"), countFailedValues = true)
+    acc
   }
-  def createTimingMetric(sc: SparkContext, name: String): LongSQLMetric = {
+  def createTimingMetric(sc: SparkContext, name: String): SQLMetric = {
     // The final result of this metric in physical operator UI may looks like:
     // duration(min, med, max):
     // 5s (800ms, 1s, 2s)
-    createLongMetric(sc, s"$name total (min, med, max)", StatisticsTimingSQLMetricParam)
-  }
-
-  def getMetricParam(metricParamName: String): SQLMetricParam[SQLMetricValue[Any], Any] = {
-    val longSQLMetricParam = Utils.getFormattedClassName(LongSQLMetricParam)
-    val bytesSQLMetricParam = Utils.getFormattedClassName(StatisticsBytesSQLMetricParam)
-    val timingsSQLMetricParam = Utils.getFormattedClassName(StatisticsTimingSQLMetricParam)
-    val metricParam = metricParamName match {
-      case `longSQLMetricParam` => LongSQLMetricParam
-      case `bytesSQLMetricParam` => StatisticsBytesSQLMetricParam
-      case `timingsSQLMetricParam` => StatisticsTimingSQLMetricParam
-    }
-    metricParam.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]]
+    val acc = new SQLMetric(TIMING_METRIC, -1)
+    acc.register(sc, name = Some(s"$name total (min, med, max)"), countFailedValues = true)
+    acc
   }
   /**
-   * A metric that its value will be ignored. Use this one when we need a metric parameter but don't
-   * care about the value.
+   * A function that defines how we aggregate the final accumulator results among all tasks,
+   * and represent it in string for a SQL physical operator.
   */
-  val nullLongMetric = new LongSQLMetric("null", LongSQLMetricParam)
+  def stringValue(metricsType: String, values: Seq[Long]): String = {
+    if (metricsType == SUM_METRIC) {
+      NumberFormat.getInstance().format(values.sum)
+    } else {
+      val strFormat: Long => String = if (metricsType == SIZE_METRIC) {
+        Utils.bytesToString
+      } else if (metricsType == TIMING_METRIC) {
+        Utils.msDurationToString
+      } else {
+        throw new IllegalStateException("unexpected metrics type: " + metricsType)
+      }
+
+      val validValues = values.filter(_ >= 0)
+      val Seq(sum, min, med, max) = {
+        val metric = if (validValues.length == 0) {
+          Seq.fill(4)(0L)
+        } else {
+          val sorted = validValues.sorted
+          Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
+        }
+        metric.map(strFormat)
+      }
+      s"\n$sum ($min, $med, $max)"
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/bf5496db/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
index 5ae9e91..9118593 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
@@ -164,7 +164,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi
         taskEnd.taskInfo.taskId,
         taskEnd.stageId,
         taskEnd.stageAttemptId,
-        taskEnd.taskMetrics.accumulatorUpdates(),
+        taskEnd.taskMetrics.accumulators().map(a => a.toInfo(Some(a.localValue), None)),
         finishTask = true)
     }
   }
@@ -296,7 +296,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi
           }
         }.filter { case (id, _) => executionUIData.accumulatorMetrics.contains(id) }
         mergeAccumulatorUpdates(accumulatorUpdates, accumulatorId =>
-          executionUIData.accumulatorMetrics(accumulatorId).metricParam)
+          executionUIData.accumulatorMetrics(accumulatorId).metricType)
       case None =>
         // This execution has been dropped
         Map.empty
@@ -305,11 +305,11 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi
   private def mergeAccumulatorUpdates(
       accumulatorUpdates: Seq[(Long, Any)],
-      paramFunc: Long => SQLMetricParam[SQLMetricValue[Any], Any]): Map[Long, String] = {
+      metricTypeFunc: Long => String): Map[Long, String] = {
     accumulatorUpdates.groupBy(_._1).map { case (accumulatorId, values) =>
-      val param = paramFunc(accumulatorId)
-      (accumulatorId,
-        param.stringValue(values.map(_._2.asInstanceOf[SQLMetricValue[Any]].value)))
+      val metricType = metricTypeFunc(accumulatorId)
+      accumulatorId ->
+        SQLMetrics.stringValue(metricType, values.map(_._2.asInstanceOf[Long]))
     }
   }
@@ -337,7 +337,7 @@ private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI)
       // Filter out accumulators that are not SQL metrics
       // For now we assume all SQL metrics are Long's that have been JSON serialized as String's
       if (a.metadata == Some(SQLMetrics.ACCUM_IDENTIFIER)) {
-        val newValue = new LongSQLMetricValue(a.update.map(_.toString.toLong).getOrElse(0L))
+        val newValue = a.update.map(_.toString.toLong).getOrElse(0L)
         Some(a.copy(update = Some(newValue)))
       } else {
         None
@@ -403,7 +403,7 @@ private[ui] class SQLExecutionUIData(
 private[ui] case class SQLPlanMetric(
     name: String,
     accumulatorId: Long,
-    metricParam: SQLMetricParam[SQLMetricValue[Any], Any])
+    metricType: String)
 /**
  * Store all accumulatorUpdates for all tasks in a Spark stage.

http://git-wip-us.apache.org/repos/asf/spark/blob/bf5496db/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
index 1959f1e..8f5681b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
@@ -80,8 +80,7 @@ private[sql] object SparkPlanGraph {
     planInfo.nodeName match {
       case "WholeStageCodegen" =>
         val metrics = planInfo.metrics.map { metric =>
-          SQLPlanMetric(metric.name, metric.accumulatorId,
-            SQLMetrics.getMetricParam(metric.metricParam))
+          SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType)
         }
         val cluster = new SparkPlanGraphCluster(
@@ -106,8 +105,7 @@ private[sql] object SparkPlanGraph {
         edges += SparkPlanGraphEdge(node.id, parent.id)
       case name =>
         val metrics = planInfo.metrics.map { metric =>
-          SQLPlanMetric(metric.name, metric.accumulatorId,
-            SQLMetrics.getMetricParam(metric.metricParam))
+          SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType)
         }
         val node = new SparkPlanGraphNode(
           nodeIdGenerator.getAndIncrement(), planInfo.nodeName,

http://git-wip-us.apache.org/repos/asf/spark/blob/bf5496db/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 4aea21e..0e6356b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -22,7 +22,7 @@ import scala.language.postfixOps
 import org.scalatest.concurrent.Eventually._
-import org.apache.spark.Accumulators
+import org.apache.spark.AccumulatorContext
 import org.apache.spark.sql.execution.RDDScanExec
 import org.apache.spark.sql.execution.columnar._
 import org.apache.spark.sql.execution.exchange.ShuffleExchange
@@ -333,11 +333,11 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
     sql("SELECT * FROM t1").count()
     sql("SELECT * FROM t2").count()
-    Accumulators.synchronized {
-      val accsSize = Accumulators.originals.size
+    AccumulatorContext.synchronized {
+      val accsSize = AccumulatorContext.originals.size
       sqlContext.uncacheTable("t1")
       sqlContext.uncacheTable("t2")
-      assert((accsSize - 2) == Accumulators.originals.size)
+      assert((accsSize - 2) == AccumulatorContext.originals.size)
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/bf5496db/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 1859c6e..8de4d8b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -37,8 +37,8 @@ import org.apache.spark.util.{JsonProtocol, Utils}
 class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
   import testImplicits._
-  test("LongSQLMetric should not box Long") {
-    val l = SQLMetrics.createLongMetric(sparkContext, "long")
+  test("SQLMetric should not box Long") {
+    val l = SQLMetrics.createMetric(sparkContext, "long")
     val f = () => {
       l += 1L
       l.add(1L)
@@ -300,12 +300,12 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
   }
   test("metrics can be loaded by history server") {
-    val metric = new LongSQLMetric("zanzibar", LongSQLMetricParam)
+    val metric = SQLMetrics.createMetric(sparkContext, "zanzibar")
     metric += 10L
     val metricInfo = metric.toInfo(Some(metric.localValue), None)
     metricInfo.update match {
-      case Some(v: LongSQLMetricValue) => assert(v.value === 10L)
-      case Some(v) => fail(s"metric value was not a LongSQLMetricValue: ${v.getClass.getName}")
+      case Some(v: Long) => assert(v === 10L)
+      case Some(v) => fail(s"metric value was not a Long: ${v.getClass.getName}")
       case _ => fail("metric update is missing")
     }
     assert(metricInfo.metadata === Some(SQLMetrics.ACCUM_IDENTIFIER))

http://git-wip-us.apache.org/repos/asf/spark/blob/bf5496db/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
index 09bd7f6..8572ed1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
@@ -21,18 +21,19 @@ import java.util.Properties
 import org.mockito.Mockito.{mock, when}
-import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite}
+import org.apache.spark._
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.scheduler._
 import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.catalyst.util.quietly
 import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution}
-import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics}
+import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.ui.SparkUI
 class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
   import testImplicits._
+  import org.apache.spark.AccumulatorSuite.makeInfo
   private def createTestDataFrame: DataFrame = {
     Seq(
@@ -72,9 +73,11 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
   private def createTaskMetrics(accumulatorUpdates: Map[Long, Long]): TaskMetrics = {
     val metrics = mock(classOf[TaskMetrics])
-    when(metrics.accumulatorUpdates()).thenReturn(accumulatorUpdates.map { case (id, update) =>
-      new AccumulableInfo(id, Some(""), Some(new LongSQLMetricValue(update)),
-        value = None, internal = true, countFailedValues = true)
+    when(metrics.accumulators()).thenReturn(accumulatorUpdates.map { case (id, update) =>
+      val acc = new LongAccumulator
+      acc.metadata = AccumulatorMetadata(id, Some(""), true)
+      acc.setValue(update)
+      acc
     }.toSeq)
     metrics
   }
@@ -130,16 +133,17 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
     listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
       // (task id, stage id, stage attempt, accum updates)
-      (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()),
-      (1L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates())
+      (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)),
+      (1L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo))
     )))
     checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2))
     listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
       // (task id, stage id, stage attempt, accum updates)
-      (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()),
-      (1L, 0, 0, createTaskMetrics(accumulatorUpdates.mapValues(_ * 2)).accumulatorUpdates())
+      (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)),
+      (1L, 0, 0,
+        createTaskMetrics(accumulatorUpdates.mapValues(_ * 2)).accumulators().map(makeInfo))
     )))
     checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 3))
@@ -149,8 +153,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
     listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
      // (task id, stage id, stage attempt, accum updates)
-      (0L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()),
-      (1L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulatorUpdates())
+      (0L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)),
+      (1L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo))
     )))
     checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2))
@@ -189,8 +193,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
     listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
       // (task id, stage id, stage attempt, accum updates)
-      (0L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()),
-      (1L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates())
+      (0L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)),
+      (1L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo))
     )))
     checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 7))
@@ -358,7 +362,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
     val stageSubmitted = SparkListenerStageSubmitted(stageInfo)
     // This task has both accumulators that are SQL metrics and accumulators that are not.
    // The listener should only track the ones that are actually SQL metrics.
-    val sqlMetric = SQLMetrics.createLongMetric(sparkContext, "beach umbrella")
+    val sqlMetric = SQLMetrics.createMetric(sparkContext, "beach umbrella")
     val nonSqlMetric = sparkContext.accumulator[Int](0, "baseball")
     val sqlMetricInfo = sqlMetric.toInfo(Some(sqlMetric.localValue), None)
     val nonSqlMetricInfo = nonSqlMetric.toInfo(Some(nonSqlMetric.localValue), None)

http://git-wip-us.apache.org/repos/asf/spark/blob/bf5496db/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index eb25ea0..8a0578c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -96,7 +96,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
         case w: WholeStageCodegenExec => w.child.longMetric("numOutputRows")
         case other => other.longMetric("numOutputRows")
       }
-      metrics += metric.value.value
+      metrics += metric.value
     }
   }
   sqlContext.listenerManager.register(listener)
@@ -126,9 +126,9 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
     override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
     override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
-      metrics += qe.executedPlan.longMetric("dataSize").value.value
+      metrics += qe.executedPlan.longMetric("dataSize").value
       val bottomAgg = qe.executedPlan.children(0).children(0)
-      metrics += bottomAgg.longMetric("dataSize").value.value
+      metrics += bottomAgg.longMetric("dataSize").value
     }
   }
   sqlContext.listenerManager.register(listener)

http://git-wip-us.apache.org/repos/asf/spark/blob/bf5496db/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
index 007c338..b52b96a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
@@ -55,7 +55,7 @@ case class HiveTableScanExec(
     "Partition pruning predicates only supported for partitioned tables.")
   private[sql] override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
   override def producedAttributes: AttributeSet = outputSet ++
     AttributeSet(partitionPruningPred.flatMap(_.references))
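
To make the new API concrete, the sketch below shows how a physical operator would create and update metrics after this change. It is illustrative only and is not part of the commit: the wrapper object and the local SparkContext are invented, and SQLMetrics is private[sql] in the patch (real callers live inside Spark's own sql packages); only createMetric, createSizeMetric, +=, add and localValue are taken from the diffs above.

import org.apache.spark.SparkContext
import org.apache.spark.sql.execution.metric.SQLMetrics

object SQLMetricSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "SQLMetric sketch")

    // A plain "sum" metric such as "number of output rows"; createMetric
    // registers it with the SparkContext with countFailedValues = true.
    val numOutputRows = SQLMetrics.createMetric(sc, "number of output rows")

    // A "size" metric starts at -1 so that tasks which never touch it can be
    // filtered out before min/med/max are computed on the driver.
    val dataSize = SQLMetrics.createSizeMetric(sc, "data size")

    // Inside a task the operator simply adds to the metric; += and add are
    // equivalent and operate directly on a Long, so no boxing is involved.
    numOutputRows += 1L
    numOutputRows.add(1L)
    dataSize.add(4096L)

    // localValue exposes the Long accumulated on this side so far.
    assert(numOutputRows.localValue == 2L)

    sc.stop()
  }
}
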
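The driver-side rendering that SQLListener now delegates to SQLMetrics.stringValue can be pictured with a few hand-made per-task values. The byte counts below are invented and the private[sql] visibility of SQLMetrics is ignored for the sake of the sketch; the "sum"/"size" metric type strings and the -1 sentinel come from the patch itself, and the printed results are only indicative.

import org.apache.spark.sql.execution.metric.SQLMetrics

object StringValueSketch {
  def main(args: Array[String]): Unit = {
    // "sum" metrics are simply formatted as a locale-aware total.
    val rows = SQLMetrics.stringValue("sum", Seq(10L, 20L, 30L))
    println(rows)  // e.g. "60"

    // "size" metrics drop the -1 placeholders left by tasks that never
    // updated the metric, then report sum (min, med, max) via bytesToString.
    val bytes = SQLMetrics.stringValue("size", Seq(-1L, 2048L, 1024L, 4096L))
    println(bytes) // e.g. "\n7.0 KB (1024.0 B, 2.0 KB, 4.0 KB)"
  }
}
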
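Finally, the accumulator life cycle that replaces the old zero/addInPlace machinery can be sketched without a cluster: each task works on a copyAndReset() copy and the driver folds the per-task results back with merge. The "sum" type string and the row counts are made up; copyAndReset, add, merge, isZero and localValue are the methods introduced in the SQLMetric class above.

import org.apache.spark.sql.execution.metric.SQLMetric

object MergeSketch {
  def main(args: Array[String]): Unit = {
    // Driver-side metric; "sum" metrics keep the default initial value of 0.
    val driverMetric = new SQLMetric("sum")

    // Each task would receive a fresh copy that starts back at the initial value.
    val taskCopies = Seq.fill(3)(driverMetric.copyAndReset())
    assert(taskCopies.forall(_.isZero()))

    // Tasks update their local copies while they run...
    taskCopies.zip(Seq(10L, 20L, 30L)).foreach { case (m, rows) => m.add(rows) }

    // ...and the driver merges the task-side values back in as tasks finish.
    taskCopies.foreach(copy => driverMetric.merge(copy))
    assert(driverMetric.localValue == 60L)
  }
}
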
