Repository: spark Updated Branches: refs/heads/master 457dc9ccb -> 0e80ecae3
[SPARK-21100][SQL][FOLLOWUP] cleanup code and add more comments for Dataset.summary ## What changes were proposed in this pull request? Clean up some code and add comments to make it more readable. Change the way result rows are generated to make it clearer. ## How was this patch tested? existing tests Author: Wenchen Fan <wenc...@databricks.com> Closes #18570 from cloud-fan/summary. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0e80ecae Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0e80ecae Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0e80ecae Branch: refs/heads/master Commit: 0e80ecae300f3e2033419b2d98da8bf092c105bb Parents: 457dc9c Author: Wenchen Fan <wenc...@databricks.com> Authored: Sun Jul 9 22:53:27 2017 -0700 Committer: Xiao Li <gatorsm...@gmail.com> Committed: Sun Jul 9 22:53:27 2017 -0700 ---------------------------------------------------------------------- .../scala/org/apache/spark/sql/Dataset.scala | 9 -- .../sql/execution/stat/StatFunctions.scala | 129 ++++++++----------- .../org/apache/spark/sql/DataFrameSuite.scala | 2 +- 3 files changed, 56 insertions(+), 84 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/0e80ecae/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 5326b45..dfb5119 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -224,15 +224,6 @@ class Dataset[T] private[sql]( } } - private[sql] def aggregatableColumns: Seq[Expression] = { - schema.fields - .filter(f => f.dataType.isInstanceOf[NumericType] || f.dataType.isInstanceOf[StringType]) - .map { n => - 
queryExecution.analyzed.resolveQuoted(n.name, sparkSession.sessionState.analyzer.resolver) - .get - } - } - /** * Compose the string representing rows for output * http://git-wip-us.apache.org/repos/asf/spark/blob/0e80ecae/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 436e18f..a75cfb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.execution.stat +import java.util.Locale + import org.apache.spark.internal.Logging import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} -import org.apache.spark.sql.catalyst.expressions.{Cast, CreateArray, Expression, GenericInternalRow, Literal} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Expression, GenericInternalRow, GetArrayItem, Literal} import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.catalyst.util.{usePrettyExpression, QuantileSummaries} +import org.apache.spark.sql.catalyst.util.QuantileSummaries import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -228,90 +231,68 @@ object StatFunctions extends Logging { val defaultStatistics = Seq("count", "mean", "stddev", "min", "25%", "50%", "75%", "max") val selectedStatistics = if (statistics.nonEmpty) statistics else defaultStatistics - val hasPercentiles = selectedStatistics.exists(_.endsWith("%")) - val (percentiles, percentileNames, remainingAggregates) = if 
(hasPercentiles) { - val (pStrings, rest) = selectedStatistics.partition(a => a.endsWith("%")) - val percentiles = pStrings.map { p => - try { - p.stripSuffix("%").toDouble / 100.0 - } catch { - case e: NumberFormatException => - throw new IllegalArgumentException(s"Unable to parse $p as a percentile", e) - } + val percentiles = selectedStatistics.filter(a => a.endsWith("%")).map { p => + try { + p.stripSuffix("%").toDouble / 100.0 + } catch { + case e: NumberFormatException => + throw new IllegalArgumentException(s"Unable to parse $p as a percentile", e) } - require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]") - (percentiles, pStrings, rest) - } else { - (Seq(), Seq(), selectedStatistics) } + require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]") - - // The list of summary statistics to compute, in the form of expressions. - val availableStatistics = Map[String, Expression => Expression]( - "count" -> ((child: Expression) => Count(child).toAggregateExpression()), - "mean" -> ((child: Expression) => Average(child).toAggregateExpression()), - "stddev" -> ((child: Expression) => StddevSamp(child).toAggregateExpression()), - "min" -> ((child: Expression) => Min(child).toAggregateExpression()), - "max" -> ((child: Expression) => Max(child).toAggregateExpression())) - - val statisticFns = remainingAggregates.map { agg => - require(availableStatistics.contains(agg), s"$agg is not a recognised statistic") - agg -> availableStatistics(agg) - } - - def percentileAgg(child: Expression): Expression = - new ApproximatePercentile(child, CreateArray(percentiles.map(Literal(_)))) - .toAggregateExpression() - - val outputCols = ds.aggregatableColumns.map(usePrettyExpression(_).sql).toList - - val ret: Seq[Row] = if (outputCols.nonEmpty) { - var aggExprs = statisticFns.toList.flatMap { case (_, colToAgg) => - outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c)) - } - if 
(hasPercentiles) { - aggExprs = outputCols.map(c => Column(percentileAgg(Column(c).expr)).as(c)) ++ aggExprs + var percentileIndex = 0 + val statisticFns = selectedStatistics.map { stats => + if (stats.endsWith("%")) { + val index = percentileIndex + percentileIndex += 1 + (child: Expression) => + GetArrayItem( + new ApproximatePercentile(child, Literal.create(percentiles)).toAggregateExpression(), + Literal(index)) + } else { + stats.toLowerCase(Locale.ROOT) match { + case "count" => (child: Expression) => Count(child).toAggregateExpression() + case "mean" => (child: Expression) => Average(child).toAggregateExpression() + case "stddev" => (child: Expression) => StddevSamp(child).toAggregateExpression() + case "min" => (child: Expression) => Min(child).toAggregateExpression() + case "max" => (child: Expression) => Max(child).toAggregateExpression() + case _ => throw new IllegalArgumentException(s"$stats is not a recognised statistic") + } } + } - val row = ds.groupBy().agg(aggExprs.head, aggExprs.tail: _*).head().toSeq + val selectedCols = ds.logicalPlan.output + .filter(a => a.dataType.isInstanceOf[NumericType] || a.dataType.isInstanceOf[StringType]) - // Pivot the data so each summary is one row - val grouped: Seq[Seq[Any]] = row.grouped(outputCols.size).toSeq + val aggExprs = statisticFns.flatMap { func => + selectedCols.map(c => Column(Cast(func(c), StringType)).as(c.name)) + } - val basicStats = if (hasPercentiles) grouped.tail else grouped + // If there are no selected columns, we don't need to run this aggregate, so make it a lazy val. + lazy val aggResult = ds.select(aggExprs: _*).queryExecution.toRdd.collect().head - val rows = basicStats.zip(statisticFns).map { case (aggregation, (statistic, _)) => - Row(statistic :: aggregation.toList: _*) - } + // We will have one row for each selected statistic in the result. 
+ val result = Array.fill[InternalRow](selectedStatistics.length) { + // each row has the statistic name, and statistic values of each selected column. + new GenericInternalRow(selectedCols.length + 1) + } - if (hasPercentiles) { - def nullSafeString(x: Any) = if (x == null) null else x.toString - val percentileRows = grouped.head - .map { - case a: Seq[Any] => a - case _ => Seq.fill(percentiles.length)(null: Any) - } - .transpose - .zip(percentileNames) - .map { case (values: Seq[Any], name) => - Row(name :: values.map(nullSafeString).toList: _*) - } - (rows ++ percentileRows) - .sortWith((left, right) => - selectedStatistics.indexOf(left(0)) < selectedStatistics.indexOf(right(0))) - } else { - rows + var rowIndex = 0 + while (rowIndex < result.length) { + val statsName = selectedStatistics(rowIndex) + result(rowIndex).update(0, UTF8String.fromString(statsName)) + for (colIndex <- selectedCols.indices) { + val statsValue = aggResult.getUTF8String(rowIndex * selectedCols.length + colIndex) + result(rowIndex).update(colIndex + 1, statsValue) } - } else { - // If there are no output columns, just output a single column that contains the stats. 
- selectedStatistics.map(Row(_)) + rowIndex += 1 } // All columns are string type - val schema = StructType( - StructField("summary", StringType) :: outputCols.map(StructField(_, StringType))).toAttributes - // `toArray` forces materialization to make the seq serializable - Dataset.ofRows(ds.sparkSession, LocalRelation.fromExternalRows(schema, ret.toArray.toSeq)) - } + val output = AttributeReference("summary", StringType)() +: + selectedCols.map(c => AttributeReference(c.name, StringType)()) + Dataset.ofRows(ds.sparkSession, LocalRelation(output, result)) + } } http://git-wip-us.apache.org/repos/asf/spark/blob/0e80ecae/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 2c7051b..b2219b4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -770,7 +770,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val fooE = intercept[IllegalArgumentException] { person2.summary("foo") } - assert(fooE.getMessage === "requirement failed: foo is not a recognised statistic") + assert(fooE.getMessage === "foo is not a recognised statistic") val parseE = intercept[IllegalArgumentException] { person2.summary("foo%") --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org