This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 6a0713a141f  [SPARK-40880][SQL] Reimplement `summary` with dataframe operations
6a0713a141f is described below

commit 6a0713a141fa98d83029d8388508cbbc40fd554e
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Mon Oct 24 10:58:13 2022 +0900

    [SPARK-40880][SQL] Reimplement `summary` with dataframe operations

    ### What changes were proposed in this pull request?
    Reimplement `summary` with dataframe operations.

    ### Why are the changes needed?
    1. Do not truncate the SQL plan any more;
    2. Enable SQL optimizations such as column pruning:

    ```
    scala> val df = spark.range(0, 3, 1, 10).withColumn("value", lit("str"))
    df: org.apache.spark.sql.DataFrame = [id: bigint, value: string]

    scala> df.summary("max", "50%").show
    +-------+---+-----+
    |summary| id|value|
    +-------+---+-----+
    |    max|  2|  str|
    |    50%|  1| null|
    +-------+---+-----+

    scala> df.summary("max", "50%").select("id").show
    +---+
    | id|
    +---+
    |  2|
    |  1|
    +---+

    scala> df.summary("max", "50%").select("id").queryExecution.optimizedPlan
    res4: org.apache.spark.sql.catalyst.plans.logical.LogicalPlan =
    Project [element_at(id#367, summary#376, None, false) AS id#371]
    +- Generate explode([max,50%]), false, [summary#376]
       +- Aggregate [map(max, cast(max(id#153L) as string), 50%, cast(percentile_approx(id#153L, [0.5], 10000, 0, 0)[0] as string)) AS id#367]
          +- Range (0, 3, step=1, splits=Some(10))
    ```

    ### Does this PR introduce _any_ user-facing change?
    No

    ### How was this patch tested?
    Existing UTs and manual checks.

    Closes #38346 from zhengruifeng/sql_stat_summary.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
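At the DataFrame API level, the rewrite boils down to three steps: one aggregate per column building a map from statistic name to stringified value, an `explode` over the statistic names, and an `element_at` lookup per column. A minimal spark-shell sketch of that pattern, assuming `spark` is in scope; the column and statistic names below are illustrative, not taken from the patch:

```scala
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StringType

val df = spark.range(0, 3).toDF("id")
val stats = Array("count", "min", "max")

// One global aggregate per column: a map from statistic name to its value as a string.
val idMap = map(
  lit("count"), count(col("id")).cast(StringType),
  lit("min"),   min(col("id")).cast(StringType),
  lit("max"),   max(col("id")).cast(StringType)
).as("id")

// Explode the statistic names into one row each, then look each name up in the map.
df.select(idMap)
  .withColumn("summary", explode(lit(stats)))
  .select(col("summary"), element_at(col("id"), col("summary")).as("id"))
  .show()
```

Because each column's statistics live in their own map column, a later `select("id")` lets the optimizer prune the other maps away, which is what produces the single-aggregate plan shown above.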
"75%", "max") + } val percentiles = selectedStatistics.filter(a => a.endsWith("%")).map { p => try { @@ -213,71 +214,66 @@ object StatFunctions extends Logging { } require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]") - def castAsDoubleIfNecessary(e: Expression): Expression = if (e.dataType == StringType) { - Cast(e, DoubleType, evalMode = EvalMode.TRY) - } else { - e - } - var percentileIndex = 0 - val statisticFns = selectedStatistics.map { stats => - if (stats.endsWith("%")) { - val index = percentileIndex - percentileIndex += 1 - (child: Expression) => - GetArrayItem( - new ApproximatePercentile(castAsDoubleIfNecessary(child), - Literal(new GenericArrayData(percentiles), ArrayType(DoubleType, false))) - .toAggregateExpression(), - Literal(index)) - } else { - stats.toLowerCase(Locale.ROOT) match { - case "count" => (child: Expression) => Count(child).toAggregateExpression() - case "count_distinct" => (child: Expression) => - Count(child).toAggregateExpression(isDistinct = true) - case "approx_count_distinct" => (child: Expression) => - HyperLogLogPlusPlus(child).toAggregateExpression() - case "mean" => (child: Expression) => - Average(castAsDoubleIfNecessary(child)).toAggregateExpression() - case "stddev" => (child: Expression) => - StddevSamp(castAsDoubleIfNecessary(child)).toAggregateExpression() - case "min" => (child: Expression) => Min(child).toAggregateExpression() - case "max" => (child: Expression) => Max(child).toAggregateExpression() - case _ => throw QueryExecutionErrors.statisticNotRecognizedError(stats) + var mapColumns = Seq.empty[Column] + var columnNames = Seq.empty[String] + + ds.schema.fields.foreach { field => + if (field.dataType.isInstanceOf[NumericType] || field.dataType.isInstanceOf[StringType]) { + val column = col(field.name) + var casted = column + if (field.dataType.isInstanceOf[StringType]) { + casted = new Column(Cast(column.expr, DoubleType, evalMode = EvalMode.TRY)) } - } - } - val selectedCols = ds.logicalPlan.output - .filter(a => a.dataType.isInstanceOf[NumericType] || a.dataType.isInstanceOf[StringType]) + val percentilesCol = if (percentiles.nonEmpty) { + percentile_approx(casted, lit(percentiles), + lit(ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY)) + } else null - val aggExprs = statisticFns.flatMap { func => - selectedCols.map(c => Column(Cast(func(c), StringType)).as(c.name)) - } + var aggColumns = Seq.empty[Column] + var percentileIndex = 0 + selectedStatistics.foreach { stats => + aggColumns :+= lit(stats) - // If there is no selected columns, we don't need to run this aggregate, so make it a lazy val. - lazy val aggResult = ds.select(aggExprs: _*).queryExecution.toRdd.collect().head + stats.toLowerCase(Locale.ROOT) match { + case "count" => aggColumns :+= count(column) - // We will have one row for each selected statistic in the result. - val result = Array.fill[InternalRow](selectedStatistics.length) { - // each row has the statistic name, and statistic values of each selected column. 
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org