This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new 4a4e35a30c7 [SPARK-38997][SQL] DS V2 aggregate push-down supports 
group by expressions
4a4e35a30c7 is described below

commit 4a4e35a30c7bb7534aece8e917a2813d47c2c498
Author: Jiaan Geng <belie...@163.com>
AuthorDate: Thu Apr 28 00:43:55 2022 +0800

    [SPARK-38997][SQL] DS V2 aggregate push-down supports group by expressions
    
    ### What changes were proposed in this pull request?
    Currently, Spark DS V2 aggregate push-down only supports grouping by columns.
    However, grouping by expressions, as in the SQL shown below, is common and useful.
    ```
    SELECT
      CASE
        WHEN "SALARY" > 8000.00
          AND "SALARY" < 10000.00
        THEN "SALARY"
        ELSE 0.00
      END AS key,
      SUM("SALARY")
    FROM "test"."employee"
    GROUP BY key
    ```
    
    ### Why are the changes needed?
    Let DS V2 aggregate push-down support group-by expressions, such as the query above.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    This is a new feature.
    
    ### How was this patch tested?
    New tests in `JDBCV2Suite`.
    
    Closes #36325 from beliefer/SPARK-38997.
    
    Authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit ee6ea3c68694e35c36ad006a7762297800d1e463)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../expressions/aggregate/Aggregation.java         |  10 +-
 .../spark/sql/execution/DataSourceScanExec.scala   |   2 +-
 .../datasources/AggregatePushDownUtils.scala       |  23 ++--
 .../execution/datasources/DataSourceStrategy.scala |   7 +-
 .../sql/execution/datasources/orc/OrcUtils.scala   |   2 +-
 .../datasources/parquet/ParquetUtils.scala         |   2 +-
 .../datasources/v2/V2ScanRelationPushDown.scala    |  23 ++--
 .../datasources/v2/jdbc/JDBCScanBuilder.scala      |  27 ++---
 .../sql/execution/datasources/v2/orc/OrcScan.scala |   2 +-
 .../datasources/v2/parquet/ParquetScan.scala       |   2 +-
 .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala    | 120 ++++++++++++++++-----
 11 files changed, 151 insertions(+), 69 deletions(-)

diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Aggregation.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Aggregation.java
index cf7dbb2978d..11d9e475ca1 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Aggregation.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Aggregation.java
@@ -20,7 +20,7 @@ package org.apache.spark.sql.connector.expressions.aggregate;
 import java.io.Serializable;
 
 import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.Expression;
 
 /**
  * Aggregation in SQL statement.
@@ -30,14 +30,14 @@ import 
org.apache.spark.sql.connector.expressions.NamedReference;
 @Evolving
 public final class Aggregation implements Serializable {
   private final AggregateFunc[] aggregateExpressions;
-  private final NamedReference[] groupByColumns;
+  private final Expression[] groupByExpressions;
 
-  public Aggregation(AggregateFunc[] aggregateExpressions, NamedReference[] 
groupByColumns) {
+  public Aggregation(AggregateFunc[] aggregateExpressions, Expression[] 
groupByExpressions) {
     this.aggregateExpressions = aggregateExpressions;
-    this.groupByColumns = groupByColumns;
+    this.groupByExpressions = groupByExpressions;
   }
 
   public AggregateFunc[] aggregateExpressions() { return aggregateExpressions; 
}
 
-  public NamedReference[] groupByColumns() { return groupByColumns; }
+  public Expression[] groupByExpressions() { return groupByExpressions; }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 5067cd7fa3c..ac0f3af5725 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -163,7 +163,7 @@ case class RowDataSourceScanExec(
       "PushedFilters" -> pushedFilters) ++
       pushedDownOperators.aggregation.fold(Map[String, String]()) { v =>
         Map("PushedAggregates" -> 
seqToString(v.aggregateExpressions.map(_.describe())),
-          "PushedGroupByColumns" -> 
seqToString(v.groupByColumns.map(_.describe())))} ++
+          "PushedGroupByExpressions" -> 
seqToString(v.groupByExpressions.map(_.describe())))} ++
       topNOrLimitInfo ++
       pushedDownOperators.sample.map(v => "PushedSample" ->
         s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} 
SEED(${v.seed})"
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala
index 4779a3eaf25..97ee3cd661b 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Expression, 
GenericInternalRow}
+import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, 
FieldReference}
 import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, 
Aggregation, Count, CountStar, Max, Min}
 import org.apache.spark.sql.execution.RowToColumnConverter
 import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils
@@ -93,8 +94,8 @@ object AggregatePushDownUtils {
       return None
     }
 
-    if (aggregation.groupByColumns.nonEmpty &&
-      partitionNames.size != aggregation.groupByColumns.length) {
+    if (aggregation.groupByExpressions.nonEmpty &&
+      partitionNames.size != aggregation.groupByExpressions.length) {
       // If there are group by columns, we only push down if the group by 
columns are the same as
       // the partition columns. In theory, if group by columns are a subset of 
partition columns,
       // we should still be able to push down. e.g. if table t has partition 
columns p1, p2, and p3,
@@ -106,11 +107,11 @@ object AggregatePushDownUtils {
       // aggregate push down simple and don't handle this complicate case for 
now.
       return None
     }
-    aggregation.groupByColumns.foreach { col =>
+    aggregation.groupByExpressions.map(extractColName).foreach { colName =>
       // don't push down if the group by columns are not the same as the 
partition columns (orders
       // doesn't matter because reorder can be done at data source layer)
-      if (col.fieldNames.length != 1 || !isPartitionCol(col.fieldNames.head)) 
return None
-      finalSchema = finalSchema.add(getStructFieldForCol(col.fieldNames.head))
+      if (colName.isEmpty || !isPartitionCol(colName.get)) return None
+      finalSchema = finalSchema.add(getStructFieldForCol(colName.get))
     }
 
     aggregation.aggregateExpressions.foreach {
@@ -137,7 +138,8 @@ object AggregatePushDownUtils {
   def equivalentAggregations(a: Aggregation, b: Aggregation): Boolean = {
     a.aggregateExpressions.sortBy(_.hashCode())
       .sameElements(b.aggregateExpressions.sortBy(_.hashCode())) &&
-      
a.groupByColumns.sortBy(_.hashCode()).sameElements(b.groupByColumns.sortBy(_.hashCode()))
+      a.groupByExpressions.sortBy(_.hashCode())
+        .sameElements(b.groupByExpressions.sortBy(_.hashCode()))
   }
 
   /**
@@ -164,7 +166,7 @@ object AggregatePushDownUtils {
   def getSchemaWithoutGroupingExpression(
       aggSchema: StructType,
       aggregation: Aggregation): StructType = {
-    val numOfGroupByColumns = aggregation.groupByColumns.length
+    val numOfGroupByColumns = aggregation.groupByExpressions.length
     if (numOfGroupByColumns > 0) {
       new StructType(aggSchema.fields.drop(numOfGroupByColumns))
     } else {
@@ -179,7 +181,7 @@ object AggregatePushDownUtils {
       partitionSchema: StructType,
       aggregation: Aggregation,
       partitionValues: InternalRow): InternalRow = {
-    val groupByColNames = aggregation.groupByColumns.map(_.fieldNames.head)
+    val groupByColNames = 
aggregation.groupByExpressions.flatMap(extractColName)
     assert(groupByColNames.length == partitionSchema.length &&
       groupByColNames.length == partitionValues.numFields, "The number of 
group by columns " +
       s"${groupByColNames.length} should be the same as partition schema 
length " +
@@ -197,4 +199,9 @@ object AggregatePushDownUtils {
       partitionValues
     }
   }
+
+  private def extractColName(v2Expr: V2Expression): Option[String] = v2Expr 
match {
+    case f: FieldReference if f.fieldNames.length == 1 => 
Some(f.fieldNames.head)
+    case _ => None
+  }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 1b14884e759..e35d0932076 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -759,14 +759,13 @@ object DataSourceStrategy
   protected[sql] def translateAggregation(
       aggregates: Seq[AggregateExpression], groupBy: Seq[Expression]): 
Option[Aggregation] = {
 
-    def columnAsString(e: Expression): Option[FieldReference] = e match {
-      case PushableColumnWithoutNestedColumn(name) =>
-        Some(FieldReference.column(name).asInstanceOf[FieldReference])
+    def translateGroupBy(e: Expression): Option[V2Expression] = e match {
+      case PushableExpression(expr) => Some(expr)
       case _ => None
     }
 
     val translatedAggregates = aggregates.flatMap(translateAggregate)
-    val translatedGroupBys = groupBy.flatMap(columnAsString)
+    val translatedGroupBys = groupBy.flatMap(translateGroupBy)
 
     if (translatedAggregates.length != aggregates.length ||
       translatedGroupBys.length != groupBy.length) {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
index 9011821e1a7..03c29894cb2 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
@@ -519,7 +519,7 @@ object OrcUtils extends Logging {
     val orcValuesDeserializer = new OrcDeserializer(schemaWithoutGroupBy,
       (0 until schemaWithoutGroupBy.length).toArray)
     val resultRow = orcValuesDeserializer.deserializeFromValues(aggORCValues)
-    if (aggregation.groupByColumns.nonEmpty) {
+    if (aggregation.groupByExpressions.nonEmpty) {
       val reOrderedPartitionValues = 
AggregatePushDownUtils.reOrderPartitionCol(
         partitionSchema, aggregation, partitionValues)
       new JoinedRow(reOrderedPartitionValues, resultRow)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
index 5a291e6a2e5..7c0348d5833 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
@@ -279,7 +279,7 @@ object ParquetUtils {
         throw new SparkException("Unexpected parquet type name: " + 
primitiveTypeNames(i))
     }
 
-    if (aggregation.groupByColumns.nonEmpty) {
+    if (aggregation.groupByExpressions.nonEmpty) {
       val reorderedPartitionValues = 
AggregatePushDownUtils.reOrderPartitionCol(
         partitionSchema, aggregation, partitionValues)
       new JoinedRow(reorderedPartitionValues, converter.currentRecord)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 92859f94888..20d508df568 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -183,9 +183,14 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] 
with PredicateHelper wit
                   // scalastyle:on
                   val newOutput = scan.readSchema().toAttributes
                   assert(newOutput.length == groupingExpressions.length + 
finalAggregates.length)
-                  val groupAttrs = 
normalizedGroupingExpressions.zip(newOutput).map {
-                    case (a: Attribute, b: Attribute) => b.withExprId(a.exprId)
-                    case (_, b) => b
+                  val groupByExprToOutputOrdinal = 
mutable.HashMap.empty[Expression, Int]
+                  val groupAttrs = 
normalizedGroupingExpressions.zip(newOutput).zipWithIndex.map {
+                    case ((a: Attribute, b: Attribute), _) => 
b.withExprId(a.exprId)
+                    case ((expr, attr), ordinal) =>
+                      if 
(!groupByExprToOutputOrdinal.contains(expr.canonicalized)) {
+                        groupByExprToOutputOrdinal(expr.canonicalized) = 
ordinal
+                      }
+                      attr
                   }
                   val aggOutput = newOutput.drop(groupAttrs.length)
                   val output = groupAttrs ++ aggOutput
@@ -196,7 +201,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] 
with PredicateHelper wit
                        |Pushed Aggregate Functions:
                        | 
${pushedAggregates.get.aggregateExpressions.mkString(", ")}
                        |Pushed Group by:
-                       | ${pushedAggregates.get.groupByColumns.mkString(", ")}
+                       | ${pushedAggregates.get.groupByExpressions.mkString(", 
")}
                        |Output: ${output.mkString(", ")}
                       """.stripMargin)
 
@@ -205,14 +210,15 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] 
with PredicateHelper wit
                     DataSourceV2ScanRelation(sHolder.relation, wrappedScan, 
output)
                   if (r.supportCompletePushDown(pushedAggregates.get)) {
                     val projectExpressions = finalResultExpressions.map { expr 
=>
-                      // TODO At present, only push down group by attribute is 
supported.
-                      // In future, more attribute conversion is extended 
here. e.g. GetStructField
-                      expr.transform {
+                      expr.transformDown {
                         case agg: AggregateExpression =>
                           val ordinal = 
aggExprToOutputOrdinal(agg.canonicalized)
                           val child =
                             addCastIfNeeded(aggOutput(ordinal), 
agg.resultAttribute.dataType)
                           Alias(child, 
agg.resultAttribute.name)(agg.resultAttribute.exprId)
+                        case expr if 
groupByExprToOutputOrdinal.contains(expr.canonicalized) =>
+                          val ordinal = 
groupByExprToOutputOrdinal(expr.canonicalized)
+                          addCastIfNeeded(groupAttrs(ordinal), expr.dataType)
                       }
                     }.asInstanceOf[Seq[NamedExpression]]
                     Project(projectExpressions, scanRelation)
@@ -255,6 +261,9 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] 
with PredicateHelper wit
                             case other => other
                           }
                         agg.copy(aggregateFunction = aggFunction)
+                      case expr if 
groupByExprToOutputOrdinal.contains(expr.canonicalized) =>
+                        val ordinal = 
groupByExprToOutputOrdinal(expr.canonicalized)
+                        addCastIfNeeded(groupAttrs(ordinal), expr.dataType)
                     }
                   }
                 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
index 0a1542a4295..8b378d2d87c 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
@@ -20,7 +20,7 @@ import scala.util.control.NonFatal
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connector.expressions.SortOrder
+import org.apache.spark.sql.connector.expressions.{FieldReference, SortOrder}
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, 
SupportsPushDownAggregates, SupportsPushDownLimit, 
SupportsPushDownRequiredColumns, SupportsPushDownTableSample, 
SupportsPushDownTopN, SupportsPushDownV2Filters}
@@ -70,12 +70,15 @@ case class JDBCScanBuilder(
 
   private var pushedAggregateList: Array[String] = Array()
 
-  private var pushedGroupByCols: Option[Array[String]] = None
+  private var pushedGroupBys: Option[Array[String]] = None
 
   override def supportCompletePushDown(aggregation: Aggregation): Boolean = {
-    lazy val fieldNames = aggregation.groupByColumns()(0).fieldNames()
+    lazy val fieldNames = aggregation.groupByExpressions()(0) match {
+      case field: FieldReference => field.fieldNames
+      case _ => Array.empty[String]
+    }
     jdbcOptions.numPartitions.map(_ == 1).getOrElse(true) ||
-      (aggregation.groupByColumns().length == 1 && fieldNames.length == 1 &&
+      (aggregation.groupByExpressions().length == 1 && fieldNames.length == 1 
&&
         jdbcOptions.partitionColumn.exists(fieldNames(0).equalsIgnoreCase(_)))
   }
 
@@ -86,20 +89,18 @@ case class JDBCScanBuilder(
     val compiledAggs = 
aggregation.aggregateExpressions.flatMap(dialect.compileAggregate)
     if (compiledAggs.length != aggregation.aggregateExpressions.length) return 
false
 
-    val groupByCols = aggregation.groupByColumns.map { col =>
-      if (col.fieldNames.length != 1) return false
-      dialect.quoteIdentifier(col.fieldNames.head)
-    }
+    val compiledGroupBys = 
aggregation.groupByExpressions.flatMap(dialect.compileExpression)
+    if (compiledGroupBys.length != aggregation.groupByExpressions.length) 
return false
 
     // The column names here are already quoted and can be used to build sql 
string directly.
     // e.g. "DEPT","NAME",MAX("SALARY"),MIN("BONUS") =>
     // SELECT "DEPT","NAME",MAX("SALARY"),MIN("BONUS") FROM "test"."employee"
     //   GROUP BY "DEPT", "NAME"
-    val selectList = groupByCols ++ compiledAggs
-    val groupByClause = if (groupByCols.isEmpty) {
+    val selectList = compiledGroupBys ++ compiledAggs
+    val groupByClause = if (compiledGroupBys.isEmpty) {
       ""
     } else {
-      "GROUP BY " + groupByCols.mkString(",")
+      "GROUP BY " + compiledGroupBys.mkString(",")
     }
 
     val aggQuery = s"SELECT ${selectList.mkString(",")} FROM 
${jdbcOptions.tableOrQuery} " +
@@ -107,7 +108,7 @@ case class JDBCScanBuilder(
     try {
       finalSchema = JDBCRDD.getQueryOutputSchema(aggQuery, jdbcOptions, 
dialect)
       pushedAggregateList = selectList
-      pushedGroupByCols = Some(groupByCols)
+      pushedGroupBys = Some(compiledGroupBys)
       true
     } catch {
       case NonFatal(e) =>
@@ -173,6 +174,6 @@ case class JDBCScanBuilder(
     // prunedSchema and quote them (will become "MAX(SALARY)", "MIN(BONUS)" 
and can't
     // be used in sql string.
     JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), finalSchema, 
pushedPredicate,
-      pushedAggregateList, pushedGroupByCols, tableSample, pushedLimit, 
sortOrders)
+      pushedAggregateList, pushedGroupBys, tableSample, pushedLimit, 
sortOrders)
   }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
index baf307257c3..d7baf493881 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
@@ -83,7 +83,7 @@ case class OrcScan(
 
   lazy private val (pushedAggregationsStr, pushedGroupByStr) = if 
(pushedAggregate.nonEmpty) {
     (seqToString(pushedAggregate.get.aggregateExpressions),
-      seqToString(pushedAggregate.get.groupByColumns))
+      seqToString(pushedAggregate.get.groupByExpressions))
   } else {
     ("[]", "[]")
   }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
index 6b35f2406a8..99632d79cd8 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
@@ -116,7 +116,7 @@ case class ParquetScan(
 
   lazy private val (pushedAggregationsStr, pushedGroupByStr) = if 
(pushedAggregate.nonEmpty) {
     (seqToString(pushedAggregate.get.aggregateExpressions),
-      seqToString(pushedAggregate.get.groupByColumns))
+      seqToString(pushedAggregate.get.groupByExpressions))
   } else {
     ("[]", "[]")
   }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index e28d9ba9ba8..30dbc7bd609 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -171,7 +171,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession 
with ExplainSuiteHel
       .groupBy("DEPT").sum("SALARY")
       .limit(1)
     checkPushedInfo(df4,
-      "PushedAggregates: [SUM(SALARY)], PushedFilters: [], 
PushedGroupByColumns: [DEPT], ")
+      "PushedAggregates: [SUM(SALARY)], PushedFilters: [], 
PushedGroupByExpressions: [DEPT], ")
     checkAnswer(df4, Seq(Row(1, 19000.00)))
 
     val name = udf { (x: String) => x.matches("cat|dav|amy") }
@@ -257,7 +257,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession 
with ExplainSuiteHel
       .limit(1)
     checkSortRemoved(df6, false)
     checkPushedInfo(df6, "PushedAggregates: [SUM(SALARY)]," +
-      " PushedFilters: [], PushedGroupByColumns: [DEPT], ")
+      " PushedFilters: [], PushedGroupByExpressions: [DEPT], ")
     checkAnswer(df6, Seq(Row(1, 19000.00)))
 
     val name = udf { (x: String) => x.matches("cat|dav|amy") }
@@ -609,7 +609,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession 
with ExplainSuiteHel
     checkAggregateRemoved(df)
     checkPushedInfo(df, "PushedAggregates: [MAX(SALARY), AVG(BONUS)], " +
       "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], " +
-      "PushedGroupByColumns: [DEPT], ")
+      "PushedGroupByExpressions: [DEPT], ")
     checkAnswer(df, Seq(Row(10000, 1100.0), Row(12000, 1250.0), Row(12000, 
1200.0)))
   }
 
@@ -630,7 +630,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession 
with ExplainSuiteHel
     checkAggregateRemoved(df)
     checkPushedInfo(df, "PushedAggregates: [MAX(ID), AVG(ID)], " +
       "PushedFilters: [ID IS NOT NULL, ID > 0], " +
-      "PushedGroupByColumns: [], ")
+      "PushedGroupByExpressions: [], ")
     checkAnswer(df, Seq(Row(2, 1.5)))
   }
 
@@ -712,18 +712,84 @@ class JDBCV2Suite extends QueryTest with 
SharedSparkSession with ExplainSuiteHel
   }
 
   test("scan with aggregate push-down: SUM with group by") {
-    val df = sql("SELECT SUM(SALARY) FROM h2.test.employee GROUP BY DEPT")
-    checkAggregateRemoved(df)
-    checkPushedInfo(df, "PushedAggregates: [SUM(SALARY)], " +
-      "PushedFilters: [], PushedGroupByColumns: [DEPT], ")
-    checkAnswer(df, Seq(Row(19000), Row(22000), Row(12000)))
+    val df1 = sql("SELECT SUM(SALARY) FROM h2.test.employee GROUP BY DEPT")
+    checkAggregateRemoved(df1)
+    checkPushedInfo(df1, "PushedAggregates: [SUM(SALARY)], " +
+      "PushedFilters: [], PushedGroupByExpressions: [DEPT], ")
+    checkAnswer(df1, Seq(Row(19000), Row(22000), Row(12000)))
+
+    val df2 = sql(
+      """
+        |SELECT CASE WHEN SALARY > 8000 AND SALARY < 10000 THEN SALARY ELSE 0 
END as key,
+        |  SUM(SALARY) FROM h2.test.employee GROUP BY key""".stripMargin)
+    checkAggregateRemoved(df2)
+    checkPushedInfo(df2,
+      """
+        |PushedAggregates: [SUM(SALARY)],
+        |PushedFilters: [],
+        |PushedGroupByExpressions:
+        |[CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY 
ELSE 0.00 END],
+        |""".stripMargin.replaceAll("\n", " "))
+    checkAnswer(df2, Seq(Row(0, 44000), Row(9000, 9000)))
+
+    val df3 = spark.read
+      .option("partitionColumn", "dept")
+      .option("lowerBound", "0")
+      .option("upperBound", "2")
+      .option("numPartitions", "2")
+      .table("h2.test.employee")
+      .groupBy(when(($"SALARY" > 8000).and($"SALARY" < 10000), 
$"SALARY").otherwise(0).as("key"))
+      .agg(sum($"SALARY"))
+    checkAggregateRemoved(df3, false)
+    checkPushedInfo(df3,
+      """
+        |PushedAggregates: [SUM(SALARY)],
+        |PushedFilters: [],
+        |PushedGroupByExpressions:
+        |[CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY 
ELSE 0.00 END],
+        |""".stripMargin.replaceAll("\n", " "))
+    checkAnswer(df3, Seq(Row(0, 44000), Row(9000, 9000)))
+
+    val df4 = sql(
+      """
+        |SELECT DEPT, CASE WHEN SALARY > 8000 AND SALARY < 10000 THEN SALARY 
ELSE 0 END as key,
+        |  SUM(SALARY) FROM h2.test.employee GROUP BY DEPT, key""".stripMargin)
+    checkAggregateRemoved(df4)
+    checkPushedInfo(df4,
+      """
+        |PushedAggregates: [SUM(SALARY)],
+        |PushedFilters: [],
+        |PushedGroupByExpressions:
+        |[DEPT, CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN 
SALARY ELSE 0.00 END],
+        |""".stripMargin.replaceAll("\n", " "))
+    checkAnswer(df4, Seq(Row(1, 0, 10000), Row(1, 9000, 9000), Row(2, 0, 
22000), Row(6, 0, 12000)))
+
+    val df5 = spark.read
+      .option("partitionColumn", "dept")
+      .option("lowerBound", "0")
+      .option("upperBound", "2")
+      .option("numPartitions", "2")
+      .table("h2.test.employee")
+      .groupBy($"DEPT",
+        when(($"SALARY" > 8000).and($"SALARY" < 10000), $"SALARY").otherwise(0)
+          .as("key"))
+      .agg(sum($"SALARY"))
+    checkAggregateRemoved(df5, false)
+    checkPushedInfo(df5,
+      """
+        |PushedAggregates: [SUM(SALARY)],
+        |PushedFilters: [],
+        |PushedGroupByExpressions:
+        |[DEPT, CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN 
SALARY ELSE 0.00 END],
+        |""".stripMargin.replaceAll("\n", " "))
+    checkAnswer(df5, Seq(Row(1, 0, 10000), Row(1, 9000, 9000), Row(2, 0, 
22000), Row(6, 0, 12000)))
   }
 
   test("scan with aggregate push-down: DISTINCT SUM with group by") {
     val df = sql("SELECT SUM(DISTINCT SALARY) FROM h2.test.employee GROUP BY 
DEPT")
     checkAggregateRemoved(df)
     checkPushedInfo(df, "PushedAggregates: [SUM(DISTINCT SALARY)], " +
-      "PushedFilters: [], PushedGroupByColumns: [DEPT]")
+      "PushedFilters: [], PushedGroupByExpressions: [DEPT]")
     checkAnswer(df, Seq(Row(19000), Row(22000), Row(12000)))
   }
 
@@ -733,7 +799,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession 
with ExplainSuiteHel
     checkFiltersRemoved(df)
     checkAggregateRemoved(df)
     checkPushedInfo(df, "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " +
-      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: 
[DEPT, NAME]")
+      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: 
[DEPT, NAME]")
     checkAnswer(df, Seq(Row(9000, 1200), Row(12000, 1200), Row(10000, 1300),
       Row(10000, 1000), Row(12000, 1200)))
   }
@@ -747,7 +813,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession 
with ExplainSuiteHel
     assert(filters1.isEmpty)
     checkAggregateRemoved(df1)
     checkPushedInfo(df1, "PushedAggregates: [MAX(SALARY)], " +
-      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: 
[DEPT, NAME]")
+      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: 
[DEPT, NAME]")
     checkAnswer(df1, Seq(Row("1#amy", 10000), Row("1#cathy", 9000), 
Row("2#alex", 12000),
       Row("2#david", 10000), Row("6#jen", 12000)))
 
@@ -759,7 +825,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession 
with ExplainSuiteHel
     assert(filters2.isEmpty)
     checkAggregateRemoved(df2)
     checkPushedInfo(df2, "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " +
-      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: 
[DEPT, NAME]")
+      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: 
[DEPT, NAME]")
     checkAnswer(df2, Seq(Row("1#amy", 11000), Row("1#cathy", 10200), 
Row("2#alex", 13200),
       Row("2#david", 11300), Row("6#jen", 13200)))
 
@@ -779,7 +845,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession 
with ExplainSuiteHel
     checkFiltersRemoved(df, false)
     checkAggregateRemoved(df)
     checkPushedInfo(df, "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " +
-      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: 
[DEPT]")
+      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: 
[DEPT]")
     checkAnswer(df, Seq(Row(12000, 1200), Row(12000, 1200)))
   }
 
@@ -789,7 +855,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession 
with ExplainSuiteHel
       .min("SALARY").as("total")
     checkAggregateRemoved(df)
     checkPushedInfo(df, "PushedAggregates: [MIN(SALARY)], " +
-      "PushedFilters: [], PushedGroupByColumns: [DEPT]")
+      "PushedFilters: [], PushedGroupByExpressions: [DEPT]")
     checkAnswer(df, Seq(Row(1, 9000), Row(2, 10000), Row(6, 12000)))
   }
 
@@ -804,7 +870,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession 
with ExplainSuiteHel
     checkFiltersRemoved(query, false)// filter over aggregate not pushed down
     checkAggregateRemoved(query)
     checkPushedInfo(query, "PushedAggregates: [SUM(SALARY)], " +
-      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: 
[DEPT]")
+      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: 
[DEPT]")
     checkAnswer(query, Seq(Row(6, 12000), Row(1, 19000), Row(2, 22000)))
   }
 
@@ -836,7 +902,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession 
with ExplainSuiteHel
     checkFiltersRemoved(df)
     checkAggregateRemoved(df)
     checkPushedInfo(df, "PushedAggregates: [VAR_POP(BONUS), VAR_SAMP(BONUS)], 
" +
-      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: 
[DEPT]")
+      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: 
[DEPT]")
     checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null)))
   }
 
@@ -846,7 +912,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession 
with ExplainSuiteHel
     checkFiltersRemoved(df)
     checkAggregateRemoved(df)
     checkPushedInfo(df, "PushedAggregates: [STDDEV_POP(BONUS), 
STDDEV_SAMP(BONUS)], " +
-      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: 
[DEPT]")
+      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: 
[DEPT]")
     checkAnswer(df, Seq(Row(100d, 141.4213562373095d), Row(50d, 
70.71067811865476d), Row(0d, null)))
   }
 
@@ -856,7 +922,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession 
with ExplainSuiteHel
     checkFiltersRemoved(df)
     checkAggregateRemoved(df)
     checkPushedInfo(df, "PushedAggregates: [COVAR_POP(BONUS, BONUS), 
COVAR_SAMP(BONUS, BONUS)], " +
-      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: 
[DEPT]")
+      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: 
[DEPT]")
     checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null)))
   }
 
@@ -866,7 +932,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession 
with ExplainSuiteHel
     checkFiltersRemoved(df)
     checkAggregateRemoved(df)
     checkPushedInfo(df, "PushedAggregates: [CORR(BONUS, BONUS)], " +
-      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: 
[DEPT]")
+      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: 
[DEPT]")
     checkAnswer(df, Seq(Row(1d), Row(1d), Row(null)))
   }
 
@@ -878,7 +944,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession 
with ExplainSuiteHel
     df2.queryExecution.optimizedPlan.collect {
       case relation: DataSourceV2ScanRelation =>
         val expectedPlanFragment =
-          "PushedAggregates: [SUM(SALARY)], PushedFilters: [], 
PushedGroupByColumns: []"
+          "PushedAggregates: [SUM(SALARY)], PushedFilters: [], 
PushedGroupByExpressions: []"
         checkKeywordsExistsInExplain(df2, expectedPlanFragment)
         relation.scan match {
           case v1: V1ScanWrapper =>
@@ -931,7 +997,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession 
with ExplainSuiteHel
       "PushedAggregates: [COUNT(CASE WHEN (SALARY > 8000.00) AND (SALARY < 
10000.00)" +
       " THEN SALARY ELSE 0.00 END), COUNT(CAS..., " +
       "PushedFilters: [], " +
-      "PushedGroupByColumns: [DEPT], ")
+      "PushedGroupByExpressions: [DEPT], ")
     checkAnswer(df, Seq(Row(1, 1, 1, 1, 1, 0d, 12000d, 0d, 12000d, 0d, 0d, 2, 
0d),
       Row(2, 2, 2, 2, 2, 10000d, 12000d, 10000d, 12000d, 0d, 0d, 3, 0d),
       Row(2, 2, 2, 2, 2, 10000d, 9000d, 10000d, 10000d, 9000d, 0d, 2, 0d)))
@@ -945,7 +1011,7 @@ class JDBCV2Suite extends QueryTest with 
SharedSparkSession with ExplainSuiteHel
         val expectedPlanFragment = if (ansiMode) {
           "PushedAggregates: [SUM(2147483647 + DEPT)], " +
             "PushedFilters: [], " +
-            "PushedGroupByColumns: []"
+            "PushedGroupByExpressions: []"
         } else {
           "PushedFilters: []"
         }
@@ -1094,7 +1160,7 @@ class JDBCV2Suite extends QueryTest with 
SharedSparkSession with ExplainSuiteHel
       .filter($"total" > 1000)
     checkAggregateRemoved(df)
     checkPushedInfo(df,
-      "PushedAggregates: [SUM(SALARY)], PushedFilters: [], 
PushedGroupByColumns: [DEPT]")
+      "PushedAggregates: [SUM(SALARY)], PushedFilters: [], 
PushedGroupByExpressions: [DEPT]")
     checkAnswer(df, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000.00)))
 
     val df2 = spark.table("h2.test.employee")
@@ -1104,7 +1170,7 @@ class JDBCV2Suite extends QueryTest with 
SharedSparkSession with ExplainSuiteHel
       .filter($"total" > 1000)
     checkAggregateRemoved(df2)
     checkPushedInfo(df2,
-      "PushedAggregates: [SUM(SALARY)], PushedFilters: [], 
PushedGroupByColumns: [DEPT]")
+      "PushedAggregates: [SUM(SALARY)], PushedFilters: [], 
PushedGroupByExpressions: [DEPT]")
     checkAnswer(df2, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000.00)))
   }
 
@@ -1121,7 +1187,7 @@ class JDBCV2Suite extends QueryTest with 
SharedSparkSession with ExplainSuiteHel
       .filter($"total" > 1000)
     checkAggregateRemoved(df, false)
     checkPushedInfo(df,
-      "PushedAggregates: [SUM(SALARY)], PushedFilters: [], 
PushedGroupByColumns: [NAME]")
+      "PushedAggregates: [SUM(SALARY)], PushedFilters: [], 
PushedGroupByExpressions: [NAME]")
     checkAnswer(df, Seq(Row("alex", 12000.00), Row("amy", 10000.00),
       Row("cathy", 9000.00), Row("david", 10000.00), Row("jen", 12000.00)))
 
@@ -1137,7 +1203,7 @@ class JDBCV2Suite extends QueryTest with 
SharedSparkSession with ExplainSuiteHel
       .filter($"total" > 1000)
     checkAggregateRemoved(df2, false)
     checkPushedInfo(df2,
-      "PushedAggregates: [SUM(SALARY)], PushedFilters: [], 
PushedGroupByColumns: [NAME]")
+      "PushedAggregates: [SUM(SALARY)], PushedFilters: [], 
PushedGroupByExpressions: [NAME]")
     checkAnswer(df2, Seq(Row("alex", 12000.00), Row("amy", 10000.00),
       Row("cathy", 9000.00), Row("david", 10000.00), Row("jen", 12000.00)))
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to