This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6a0713a141f [SPARK-40880][SQL] Reimplement `summary` with dataframe operations
6a0713a141f is described below

commit 6a0713a141fa98d83029d8388508cbbc40fd554e
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Mon Oct 24 10:58:13 2022 +0900

    [SPARK-40880][SQL] Reimplement `summary` with dataframe operations
    
    ### What changes were proposed in this pull request?
    Reimplement `summary` with dataframe operations
    
    ### Why are the changes needed?
    1. no longer truncates the SQL plan;
    2. enables SQL optimizations such as column pruning:
    
    ```
    scala> val df = spark.range(0, 3, 1, 10).withColumn("value", lit("str"))
    df: org.apache.spark.sql.DataFrame = [id: bigint, value: string]
    
    scala> df.summary("max", "50%").show
    +-------+---+-----+
    |summary| id|value|
    +-------+---+-----+
    |    max|  2|  str|
    |    50%|  1| null|
    +-------+---+-----+
    
    scala> df.summary("max", "50%").select("id").show
    +---+
    | id|
    +---+
    |  2|
    |  1|
    +---+
    
    scala> df.summary("max", "50%").select("id").queryExecution.optimizedPlan
    res4: org.apache.spark.sql.catalyst.plans.logical.LogicalPlan =
    Project [element_at(id#367, summary#376, None, false) AS id#371]
    +- Generate explode([max,50%]), false, [summary#376]
       +- Aggregate [map(max, cast(max(id#153L) as string), 50%, cast(percentile_approx(id#153L, [0.5], 10000, 0, 0)[0] as string)) AS id#367]
          +- Range (0, 3, step=1, splits=Some(10))
    
    ```
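
    For reference, the essence of the new implementation can be sketched with public dataframe operations (a minimal sketch covering only the `id` column of the session above; the patch builds the maps generically per field and constructs `ElementAt` directly, so `element_at` here is only for illustration): each selected column is aggregated into one map from statistic name to string value, the statistic names are exploded into a `summary` column, and each value is looked up by that key:

    ```
    import org.apache.spark.sql.functions._

    val stats = Array("max", "50%")

    // one map per selected column: statistic name -> value, all values cast to string
    val idMap = map(
      lit("max"), max(col("id")).cast("string"),
      lit("50%"), percentile_approx(col("id"), lit(0.5), lit(10000)).cast("string")
    ).as("id")

    df.select(idMap)                              // single global aggregate producing a map column
      .withColumn("summary", explode(lit(stats))) // one output row per requested statistic
      .select(col("summary"), element_at(col("id"), col("summary")).as("id"))
      .show()
    ```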
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing unit tests and manual checks.
    
    Closes #38346 from zhengruifeng/sql_stat_summary.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../spark/sql/execution/stat/StatFunctions.scala   | 122 ++++++++++-----------
 1 file changed, 59 insertions(+), 63 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 484be76b991..508d2c64d09 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -21,11 +21,10 @@ import java.util.Locale
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, EvalMode, Expression, GenericInternalRow, GetArrayItem, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Cast, ElementAt, EvalMode, GenericInternalRow}
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.catalyst.util.{GenericArrayData, QuantileSummaries}
+import org.apache.spark.sql.catalyst.util.QuantileSummaries
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
@@ -199,9 +198,11 @@ object StatFunctions extends Logging {
 
   /** Calculate selected summary statistics for a dataset */
   def summary(ds: Dataset[_], statistics: Seq[String]): DataFrame = {
-
-    val defaultStatistics = Seq("count", "mean", "stddev", "min", "25%", "50%", "75%", "max")
-    val selectedStatistics = if (statistics.nonEmpty) statistics else defaultStatistics
+    val selectedStatistics = if (statistics.nonEmpty) {
+      statistics.toArray
+    } else {
+      Array("count", "mean", "stddev", "min", "25%", "50%", "75%", "max")
+    }
 
     val percentiles = selectedStatistics.filter(a => a.endsWith("%")).map { p =>
       try {
@@ -213,71 +214,66 @@ object StatFunctions extends Logging {
     }
     require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]")
 
-    def castAsDoubleIfNecessary(e: Expression): Expression = if (e.dataType == StringType) {
-      Cast(e, DoubleType, evalMode = EvalMode.TRY)
-    } else {
-      e
-    }
-    var percentileIndex = 0
-    val statisticFns = selectedStatistics.map { stats =>
-      if (stats.endsWith("%")) {
-        val index = percentileIndex
-        percentileIndex += 1
-        (child: Expression) =>
-          GetArrayItem(
-            new ApproximatePercentile(castAsDoubleIfNecessary(child),
-              Literal(new GenericArrayData(percentiles), ArrayType(DoubleType, false)))
-              .toAggregateExpression(),
-            Literal(index))
-      } else {
-        stats.toLowerCase(Locale.ROOT) match {
-          case "count" => (child: Expression) => 
Count(child).toAggregateExpression()
-          case "count_distinct" => (child: Expression) =>
-            Count(child).toAggregateExpression(isDistinct = true)
-          case "approx_count_distinct" => (child: Expression) =>
-            HyperLogLogPlusPlus(child).toAggregateExpression()
-          case "mean" => (child: Expression) =>
-            Average(castAsDoubleIfNecessary(child)).toAggregateExpression()
-          case "stddev" => (child: Expression) =>
-            StddevSamp(castAsDoubleIfNecessary(child)).toAggregateExpression()
-          case "min" => (child: Expression) => 
Min(child).toAggregateExpression()
-          case "max" => (child: Expression) => 
Max(child).toAggregateExpression()
-          case _ => throw 
QueryExecutionErrors.statisticNotRecognizedError(stats)
+    var mapColumns = Seq.empty[Column]
+    var columnNames = Seq.empty[String]
+
+    ds.schema.fields.foreach { field =>
+      if (field.dataType.isInstanceOf[NumericType] || field.dataType.isInstanceOf[StringType]) {
+        val column = col(field.name)
+        var casted = column
+        if (field.dataType.isInstanceOf[StringType]) {
+          casted = new Column(Cast(column.expr, DoubleType, evalMode = EvalMode.TRY))
         }
-      }
-    }
 
-    val selectedCols = ds.logicalPlan.output
-      .filter(a => a.dataType.isInstanceOf[NumericType] || a.dataType.isInstanceOf[StringType])
+        val percentilesCol = if (percentiles.nonEmpty) {
+          percentile_approx(casted, lit(percentiles),
+            lit(ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY))
+        } else null
 
-    val aggExprs = statisticFns.flatMap { func =>
-      selectedCols.map(c => Column(Cast(func(c), StringType)).as(c.name))
-    }
+        var aggColumns = Seq.empty[Column]
+        var percentileIndex = 0
+        selectedStatistics.foreach { stats =>
+          aggColumns :+= lit(stats)
 
-    // If there is no selected columns, we don't need to run this aggregate, so make it a lazy val.
-    lazy val aggResult = ds.select(aggExprs: _*).queryExecution.toRdd.collect().head
+          stats.toLowerCase(Locale.ROOT) match {
+            case "count" => aggColumns :+= count(column)
 
-    // We will have one row for each selected statistic in the result.
-    val result = Array.fill[InternalRow](selectedStatistics.length) {
-      // each row has the statistic name, and statistic values of each selected column.
-      new GenericInternalRow(selectedCols.length + 1)
-    }
+            case "count_distinct" => aggColumns :+= count_distinct(column)
+
+            case "approx_count_distinct" => aggColumns :+= 
approx_count_distinct(column)
 
-    var rowIndex = 0
-    while (rowIndex < result.length) {
-      val statsName = selectedStatistics(rowIndex)
-      result(rowIndex).update(0, UTF8String.fromString(statsName))
-      for (colIndex <- selectedCols.indices) {
-        val statsValue = aggResult.getUTF8String(rowIndex * selectedCols.length + colIndex)
-        result(rowIndex).update(colIndex + 1, statsValue)
+            case "mean" => aggColumns :+= avg(casted)
+
+            case "stddev" => aggColumns :+= stddev(casted)
+
+            case "min" => aggColumns :+= min(column)
+
+            case "max" => aggColumns :+= max(column)
+
+            case percentile if percentile.endsWith("%") =>
+              aggColumns :+= get(percentilesCol, lit(percentileIndex))
+              percentileIndex += 1
+
+            case _ => throw QueryExecutionErrors.statisticNotRecognizedError(stats)
+          }
+        }
+
+        // map { "count" -> "1024", "min" -> "1.0", ... }
+        mapColumns :+= map(aggColumns.map(_.cast(StringType)): _*).as(field.name)
+        columnNames :+= field.name
       }
-      rowIndex += 1
     }
 
-    // All columns are string type
-    val output = AttributeReference("summary", StringType)() +:
-      selectedCols.map(c => AttributeReference(c.name, StringType)())
-
-    Dataset.ofRows(ds.sparkSession, LocalRelation(output, result))
+    if (mapColumns.isEmpty) {
+      ds.sparkSession.createDataFrame(selectedStatistics.map(Tuple1.apply))
+        .withColumnRenamed("_1", "summary")
+    } else {
+      val valueColumns = columnNames.map { columnName =>
+        new Column(ElementAt(col(columnName).expr, col("summary").expr)).as(columnName)
+      }
+      ds.select(mapColumns: _*)
+        .withColumn("summary", explode(lit(selectedStatistics)))
+        .select(Array(col("summary")) ++ valueColumns: _*)
+    }
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
