This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 6a0713a141f [SPARK-40880][SQL] Reimplement `summary` with dataframe
operations
6a0713a141f is described below
commit 6a0713a141fa98d83029d8388508cbbc40fd554e
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Oct 24 10:58:13 2022 +0900
[SPARK-40880][SQL] Reimplement `summary` with dataframe operations
### What changes were proposed in this pull request?
Reimplement `summary` with dataframe operations
### Why are the changes needed?
1. Do not truncate the SQL plan any more;
2. Enable SQL optimizations such as column pruning:
```
scala> val df = spark.range(0, 3, 1, 10).withColumn("value", lit("str"))
df: org.apache.spark.sql.DataFrame = [id: bigint, value: string]
scala> df.summary("max", "50%").show
+-------+---+-----+
|summary| id|value|
+-------+---+-----+
| max| 2| str|
| 50%| 1| null|
+-------+---+-----+
scala> df.summary("max", "50%").select("id").show
+---+
| id|
+---+
| 2|
| 1|
+---+
scala> df.summary("max", "50%").select("id").queryExecution.optimizedPlan
res4: org.apache.spark.sql.catalyst.plans.logical.LogicalPlan =
Project [element_at(id#367, summary#376, None, false) AS id#371]
+- Generate explode([max,50%]), false, [summary#376]
+- Aggregate [map(max, cast(max(id#153L) as string), 50%,
cast(percentile_approx(id#153L, [0.5], 10000, 0, 0)[0] as string)) AS id#367]
+- Range (0, 3, step=1, splits=Some(10))
```
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
existing UTs and manual checks
Closes #38346 from zhengruifeng/sql_stat_summary.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../spark/sql/execution/stat/StatFunctions.scala | 122 ++++++++++-----------
1 file changed, 59 insertions(+), 63 deletions(-)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 484be76b991..508d2c64d09 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -21,11 +21,10 @@ import java.util.Locale
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast,
EvalMode, Expression, GenericInternalRow, GetArrayItem, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Cast, ElementAt, EvalMode,
GenericInternalRow}
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.catalyst.util.{GenericArrayData, QuantileSummaries}
+import org.apache.spark.sql.catalyst.util.QuantileSummaries
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@@ -199,9 +198,11 @@ object StatFunctions extends Logging {
/** Calculate selected summary statistics for a dataset */
def summary(ds: Dataset[_], statistics: Seq[String]): DataFrame = {
-
- val defaultStatistics = Seq("count", "mean", "stddev", "min", "25%",
"50%", "75%", "max")
- val selectedStatistics = if (statistics.nonEmpty) statistics else
defaultStatistics
+ val selectedStatistics = if (statistics.nonEmpty) {
+ statistics.toArray
+ } else {
+ Array("count", "mean", "stddev", "min", "25%", "50%", "75%", "max")
+ }
val percentiles = selectedStatistics.filter(a => a.endsWith("%")).map { p
=>
try {
@@ -213,71 +214,66 @@ object StatFunctions extends Logging {
}
require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in
the range [0, 1]")
- def castAsDoubleIfNecessary(e: Expression): Expression = if (e.dataType ==
StringType) {
- Cast(e, DoubleType, evalMode = EvalMode.TRY)
- } else {
- e
- }
- var percentileIndex = 0
- val statisticFns = selectedStatistics.map { stats =>
- if (stats.endsWith("%")) {
- val index = percentileIndex
- percentileIndex += 1
- (child: Expression) =>
- GetArrayItem(
- new ApproximatePercentile(castAsDoubleIfNecessary(child),
- Literal(new GenericArrayData(percentiles), ArrayType(DoubleType,
false)))
- .toAggregateExpression(),
- Literal(index))
- } else {
- stats.toLowerCase(Locale.ROOT) match {
- case "count" => (child: Expression) =>
Count(child).toAggregateExpression()
- case "count_distinct" => (child: Expression) =>
- Count(child).toAggregateExpression(isDistinct = true)
- case "approx_count_distinct" => (child: Expression) =>
- HyperLogLogPlusPlus(child).toAggregateExpression()
- case "mean" => (child: Expression) =>
- Average(castAsDoubleIfNecessary(child)).toAggregateExpression()
- case "stddev" => (child: Expression) =>
- StddevSamp(castAsDoubleIfNecessary(child)).toAggregateExpression()
- case "min" => (child: Expression) =>
Min(child).toAggregateExpression()
- case "max" => (child: Expression) =>
Max(child).toAggregateExpression()
- case _ => throw
QueryExecutionErrors.statisticNotRecognizedError(stats)
+ var mapColumns = Seq.empty[Column]
+ var columnNames = Seq.empty[String]
+
+ ds.schema.fields.foreach { field =>
+ if (field.dataType.isInstanceOf[NumericType] ||
field.dataType.isInstanceOf[StringType]) {
+ val column = col(field.name)
+ var casted = column
+ if (field.dataType.isInstanceOf[StringType]) {
+ casted = new Column(Cast(column.expr, DoubleType, evalMode =
EvalMode.TRY))
}
- }
- }
- val selectedCols = ds.logicalPlan.output
- .filter(a => a.dataType.isInstanceOf[NumericType] ||
a.dataType.isInstanceOf[StringType])
+ val percentilesCol = if (percentiles.nonEmpty) {
+ percentile_approx(casted, lit(percentiles),
+ lit(ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY))
+ } else null
- val aggExprs = statisticFns.flatMap { func =>
- selectedCols.map(c => Column(Cast(func(c), StringType)).as(c.name))
- }
+ var aggColumns = Seq.empty[Column]
+ var percentileIndex = 0
+ selectedStatistics.foreach { stats =>
+ aggColumns :+= lit(stats)
- // If there is no selected columns, we don't need to run this aggregate,
so make it a lazy val.
- lazy val aggResult = ds.select(aggExprs:
_*).queryExecution.toRdd.collect().head
+ stats.toLowerCase(Locale.ROOT) match {
+ case "count" => aggColumns :+= count(column)
- // We will have one row for each selected statistic in the result.
- val result = Array.fill[InternalRow](selectedStatistics.length) {
- // each row has the statistic name, and statistic values of each
selected column.
- new GenericInternalRow(selectedCols.length + 1)
- }
+ case "count_distinct" => aggColumns :+= count_distinct(column)
+
+ case "approx_count_distinct" => aggColumns :+=
approx_count_distinct(column)
- var rowIndex = 0
- while (rowIndex < result.length) {
- val statsName = selectedStatistics(rowIndex)
- result(rowIndex).update(0, UTF8String.fromString(statsName))
- for (colIndex <- selectedCols.indices) {
- val statsValue = aggResult.getUTF8String(rowIndex *
selectedCols.length + colIndex)
- result(rowIndex).update(colIndex + 1, statsValue)
+ case "mean" => aggColumns :+= avg(casted)
+
+ case "stddev" => aggColumns :+= stddev(casted)
+
+ case "min" => aggColumns :+= min(column)
+
+ case "max" => aggColumns :+= max(column)
+
+ case percentile if percentile.endsWith("%") =>
+ aggColumns :+= get(percentilesCol, lit(percentileIndex))
+ percentileIndex += 1
+
+ case _ => throw
QueryExecutionErrors.statisticNotRecognizedError(stats)
+ }
+ }
+
+ // map { "count" -> "1024", "min" -> "1.0", ... }
+ mapColumns :+= map(aggColumns.map(_.cast(StringType)):
_*).as(field.name)
+ columnNames :+= field.name
}
- rowIndex += 1
}
- // All columns are string type
- val output = AttributeReference("summary", StringType)() +:
- selectedCols.map(c => AttributeReference(c.name, StringType)())
-
- Dataset.ofRows(ds.sparkSession, LocalRelation(output, result))
+ if (mapColumns.isEmpty) {
+ ds.sparkSession.createDataFrame(selectedStatistics.map(Tuple1.apply))
+ .withColumnRenamed("_1", "summary")
+ } else {
+ val valueColumns = columnNames.map { columnName =>
+ new Column(ElementAt(col(columnName).expr,
col("summary").expr)).as(columnName)
+ }
+ ds.select(mapColumns: _*)
+ .withColumn("summary", explode(lit(selectedStatistics)))
+ .select(Array(col("summary")) ++ valueColumns: _*)
+ }
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]