This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d9e3d9b  [SPARK-37644][SQL] Support datasource v2 complete aggregate pushdown
d9e3d9b is described below

commit d9e3d9b9d97e7a238062060675913e29b9184cfb
Author: Jiaan Geng <belie...@163.com>
AuthorDate: Thu Dec 23 23:05:20 2021 +0800

    [SPARK-37644][SQL] Support datasource v2 complete aggregate pushdown
    
    ### What changes were proposed in this pull request?
    Currently, Spark supports aggregate push-down with a partial-agg and a final-agg. For some data sources (e.g. JDBC), we can avoid the partial-agg and final-agg by running the aggregation completely in the database.
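
    As a rough illustration of the new API (a sketch, not code from this commit: the connector class and its delegate are hypothetical, while supportCompletePushDown and pushAggregation are the interface methods touched here), a data source can opt in to complete push-down like this:

        import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
        import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates}

        // Hypothetical connector: returning true from supportCompletePushDown tells Spark that
        // the source computes the aggregation exactly, so both the partial and the final
        // Aggregate nodes can be removed from the query plan.
        class MyCompleteAggScanBuilder(delegate: ScanBuilder) extends SupportsPushDownAggregates {

          private var pushed: Option[Aggregation] = None

          override def supportCompletePushDown(): Boolean = true

          override def pushAggregation(aggregation: Aggregation): Boolean = {
            pushed = Some(aggregation) // a real connector would also adjust its output schema
            true
          }

          override def build(): Scan = delegate.build()
        }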
    
    ### Why are the changes needed?
    Improve performance for aggregate pushdown.
    
    ### Does this PR introduce _any_ user-facing change?
    'No'. This only changes the internal implementation.
    
    ### How was this patch tested?
    New tests.
    
    Closes #34904 from beliefer/SPARK-37644.
    
    Authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../connector/read/SupportsPushDownAggregates.java |   8 ++
 .../datasources/v2/V2ScanRelationPushDown.scala    | 101 +++++++++++++--------
 .../datasources/v2/jdbc/JDBCScanBuilder.scala      |   3 +
 .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala    | 101 ++++++++++++++++++++-
 4 files changed, 173 insertions(+), 40 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java
index 3e643b5..4e6c59e 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java
@@ -46,6 +46,14 @@ import org.apache.spark.sql.connector.expressions.aggregate.Aggregation;
 public interface SupportsPushDownAggregates extends ScanBuilder {
 
   /**
+   * Whether the datasource supports complete aggregation push-down. Spark could avoid partial-agg
+   * and final-agg when the aggregation operation can be pushed down to the datasource completely.
+   *
+   * @return true if the aggregation can be pushed down to the datasource completely, false otherwise.
+   */
+  default boolean supportCompletePushDown() { return false; }
+
+  /**
    * Pushes down Aggregation to datasource. The order of the datasource scan output columns should
    * be: grouping columns, aggregate columns (in the same order as the aggregate functions in
    * the given Aggregation).
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index e7c06d0..3a792f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import scala.collection.mutable
 
-import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.planning.ScanOperation
@@ -30,7 +30,7 @@ import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.sources
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, LongType, StructType}
 import org.apache.spark.sql.util.SchemaUtils._
 
 object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
@@ -131,7 +131,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
                   case (a: Attribute, b: Attribute) => b.withExprId(a.exprId)
                   case (_, b) => b
                 }
-                val output = groupAttrs ++ newOutput.drop(groupAttrs.length)
+                val aggOutput = newOutput.drop(groupAttrs.length)
+                val output = groupAttrs ++ aggOutput
 
                 logInfo(
                   s"""
@@ -147,40 +148,59 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
 
                 val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
 
-                val plan = Aggregate(
-                  output.take(groupingExpressions.length), resultExpressions, scanRelation)
-
-                // scalastyle:off
-                // Change the optimized logical plan to reflect the pushed down aggregate
-                // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
-                // SELECT min(c1), max(c1) FROM t GROUP BY c2;
-                // The original logical plan is
-                // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
-                // +- RelationV2[c1#9, c2#10] ...
-                //
-                // After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22]
-                // we have the following
-                // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
-                // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
-                //
-                // We want to change it to
-                // == Optimized Logical Plan ==
-                // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18]
-                // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
-                // scalastyle:on
-                val aggOutput = output.drop(groupAttrs.length)
-                plan.transformExpressions {
-                  case agg: AggregateExpression =>
-                    val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
-                    val aggFunction: aggregate.AggregateFunction =
-                      agg.aggregateFunction match {
-                        case max: aggregate.Max => max.copy(child = aggOutput(ordinal))
-                        case min: aggregate.Min => min.copy(child = aggOutput(ordinal))
-                        case sum: aggregate.Sum => sum.copy(child = aggOutput(ordinal))
-                        case _: aggregate.Count => aggregate.Sum(aggOutput(ordinal))
-                        case other => other
-                      }
-                    agg.copy(aggregateFunction = aggFunction)
+                if (r.supportCompletePushDown()) {
+                  val projectExpressions = resultExpressions.map { expr =>
+                    // TODO At present, only push down group by attribute is supported.
+                    // In future, more attribute conversion is extended here. e.g. GetStructField
+                    expr.transform {
+                      case agg: AggregateExpression =>
+                        val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
+                        val child =
+                          addCastIfNeeded(aggOutput(ordinal), agg.resultAttribute.dataType)
+                        Alias(child, agg.resultAttribute.name)(agg.resultAttribute.exprId)
+                    }
+                  }.asInstanceOf[Seq[NamedExpression]]
+                  Project(projectExpressions, scanRelation)
+                } else {
+                  val plan = Aggregate(
+                    output.take(groupingExpressions.length), resultExpressions, scanRelation)
+
+                  // scalastyle:off
+                  // Change the optimized logical plan to reflect the pushed down aggregate
+                  // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
+                  // SELECT min(c1), max(c1) FROM t GROUP BY c2;
+                  // The original logical plan is
+                  // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
+                  // +- RelationV2[c1#9, c2#10] ...
+                  //
+                  // After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22]
+                  // we have the following
+                  // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
+                  // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
+                  //
+                  // We want to change it to
+                  // == Optimized Logical Plan ==
+                  // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18]
+                  // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
+                  // scalastyle:on
+                  plan.transformExpressions {
+                    case agg: AggregateExpression =>
+                      val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
+                      val aggAttribute = aggOutput(ordinal)
+                      val aggFunction: aggregate.AggregateFunction =
+                        agg.aggregateFunction match {
+                          case max: aggregate.Max =>
+                            max.copy(child = addCastIfNeeded(aggAttribute, max.child.dataType))
+                          case min: aggregate.Min =>
+                            min.copy(child = addCastIfNeeded(aggAttribute, min.child.dataType))
+                          case sum: aggregate.Sum =>
+                            sum.copy(child = addCastIfNeeded(aggAttribute, sum.child.dataType))
+                          case _: aggregate.Count =>
+                            aggregate.Sum(addCastIfNeeded(aggAttribute, LongType))
+                          case other => other
+                        }
+                      agg.copy(aggregateFunction = aggFunction)
+                  }
                 }
               }
             case _ => aggNode
@@ -189,6 +209,13 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
       }
   }
 
+  private def addCastIfNeeded(aggAttribute: AttributeReference, aggDataType: DataType) =
+    if (aggAttribute.dataType == aggDataType) {
+      aggAttribute
+    } else {
+      Cast(aggAttribute, aggDataType)
+    }
+
   def applyColumnPruning(plan: LogicalPlan): LogicalPlan = plan.transform {
     case ScanOperation(project, filters, sHolder: ScanBuilderHolder) =>
       // column pruning
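
For the complete push-down branch added above, the corresponding plan rewrite can be sketched as follows (reusing the expression IDs from the example in the code comments; the exact plan Spark prints may differ):

  For TABLE t (c1 INT, c2 INT, c3 INT) and SELECT min(c1), max(c1) FROM t GROUP BY c2, when
  supportCompletePushDown() returns true the whole Aggregate is dropped and replaced by a Project
  that aliases the scan output, with casts added where the pushed result type differs:

  Project [c2#10, min(c1)#21 AS min(c1)#17, max(c1)#22 AS max(c1)#18]
  +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...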
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
index 1760122..01722e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
@@ -72,6 +72,9 @@ case class JDBCScanBuilder(
 
   private var pushedGroupByCols: Option[Array[String]] = None
 
+  override def supportCompletePushDown: Boolean =
+    jdbcOptions.numPartitions.map(_ == 1).getOrElse(true)
+
   override def pushAggregation(aggregation: Aggregation): Boolean = {
     if (!jdbcOptions.pushDownAggregate) return false
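
A brief usage sketch of this JDBCScanBuilder change (assuming an active SparkSession named spark; the JDBC URL and table name below are placeholders): with pushDownAggregate enabled and at most one partition, the builder reports complete push-down support and a grouped aggregate runs entirely in the database, while numPartitions greater than 1 falls back to the partial-agg plus final-agg path.

  // Sketch only: the URL and table are placeholders, not taken from this commit.
  val employees = spark.read
    .format("jdbc")
    .option("url", "jdbc:h2:mem:testdb")
    .option("dbtable", "test.employee")
    .option("pushDownAggregate", "true")
    .option("numPartitions", "1") // > 1 disables complete push-down in this builder
    .load()

  // With complete push-down, the optimized plan keeps no Aggregate node above the scan.
  employees.groupBy("DEPT").max("SALARY").show()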
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 5df875b..eb3c8d6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -23,7 +23,7 @@ import java.util.Properties
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, QueryTest, Row}
 import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, Sort}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sort}
 import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortValue}
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
 import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
@@ -388,6 +388,17 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     }
   }
 
+  private def checkAggregateRemoved(df: DataFrame, removed: Boolean = true): Unit = {
+    val aggregates = df.queryExecution.optimizedPlan.collect {
+      case agg: Aggregate => agg
+    }
+    if (removed) {
+      assert(aggregates.isEmpty)
+    } else {
+      assert(aggregates.nonEmpty)
+    }
+  }
+
   test("scan with aggregate push-down: MAX MIN with filter and group by") {
     val df = sql("select MAX(SaLaRY), MIN(BONUS) FROM h2.test.employee where dept > 0" +
       " group by DePt")
@@ -395,6 +406,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       case f: Filter => f
     }
     assert(filters.isEmpty)
+    checkAggregateRemoved(df)
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
@@ -412,6 +424,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       case f: Filter => f
     }
     assert(filters.isEmpty)
+    checkAggregateRemoved(df)
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
@@ -425,6 +438,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
 
   test("scan with aggregate push-down: aggregate + number") {
     val df = sql("select MAX(SALARY) + 1 FROM h2.test.employee")
+    checkAggregateRemoved(df)
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
@@ -436,6 +450,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
 
   test("scan with aggregate push-down: COUNT(*)") {
     val df = sql("select COUNT(*) FROM h2.test.employee")
+    checkAggregateRemoved(df)
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
@@ -447,6 +462,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
 
   test("scan with aggregate push-down: COUNT(col)") {
     val df = sql("select COUNT(DEPT) FROM h2.test.employee")
+    checkAggregateRemoved(df)
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
@@ -458,6 +474,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
 
   test("scan with aggregate push-down: COUNT(DISTINCT col)") {
     val df = sql("select COUNT(DISTINCT DEPT) FROM h2.test.employee")
+    checkAggregateRemoved(df)
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
@@ -469,6 +486,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
 
   test("scan with aggregate push-down: SUM without filer and group by") {
     val df = sql("SELECT SUM(SALARY) FROM h2.test.employee")
+    checkAggregateRemoved(df)
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
@@ -480,6 +498,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
 
   test("scan with aggregate push-down: DISTINCT SUM without filer and group by") {
     val df = sql("SELECT SUM(DISTINCT SALARY) FROM h2.test.employee")
+    checkAggregateRemoved(df)
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
@@ -491,6 +510,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
 
   test("scan with aggregate push-down: SUM with group by") {
     val df = sql("SELECT SUM(SALARY) FROM h2.test.employee GROUP BY DEPT")
+    checkAggregateRemoved(df)
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
@@ -504,6 +524,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
 
   test("scan with aggregate push-down: DISTINCT SUM with group by") {
     val df = sql("SELECT SUM(DISTINCT SALARY) FROM h2.test.employee GROUP BY DEPT")
+    checkAggregateRemoved(df)
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
@@ -518,10 +539,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
   test("scan with aggregate push-down: with multiple group by columns") {
     val df = sql("select MAX(SALARY), MIN(BONUS) FROM h2.test.employee where dept > 0" +
       " group by DEPT, NAME")
-    val filters11 = df.queryExecution.optimizedPlan.collect {
+    val filters = df.queryExecution.optimizedPlan.collect {
       case f: Filter => f
     }
-    assert(filters11.isEmpty)
+    assert(filters.isEmpty)
+    checkAggregateRemoved(df)
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
@@ -534,6 +556,60 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       Row(10000, 1000), Row(12000, 1200)))
   }
 
+  test("scan with aggregate push-down: with concat multiple group key in project") {
+    val df1 = sql("select concat_ws('#', DEPT, NAME), MAX(SALARY) FROM h2.test.employee" +
+      " where dept > 0 group by DEPT, NAME")
+    val filters1 = df1.queryExecution.optimizedPlan.collect {
+      case f: Filter => f
+    }
+    assert(filters1.isEmpty)
+    checkAggregateRemoved(df1)
+    df1.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expected_plan_fragment =
+          "PushedAggregates: [MAX(SALARY)], " +
+            "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
+            "PushedGroupByColumns: [DEPT, NAME]"
+        checkKeywordsExistsInExplain(df1, expected_plan_fragment)
+    }
+    checkAnswer(df1, Seq(Row("1#amy", 10000), Row("1#cathy", 9000), Row("2#alex", 12000),
+      Row("2#david", 10000), Row("6#jen", 12000)))
+
+    val df2 = sql("select concat_ws('#', DEPT, NAME), MAX(SALARY) + MIN(BONUS)" +
+      " FROM h2.test.employee where dept > 0 group by DEPT, NAME")
+    val filters2 = df2.queryExecution.optimizedPlan.collect {
+      case f: Filter => f
+    }
+    assert(filters2.isEmpty)
+    checkAggregateRemoved(df2)
+    df2.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expected_plan_fragment =
+          "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " +
+            "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
+            "PushedGroupByColumns: [DEPT, NAME]"
+        checkKeywordsExistsInExplain(df2, expected_plan_fragment)
+    }
+    checkAnswer(df2, Seq(Row("1#amy", 11000), Row("1#cathy", 10200), Row("2#alex", 13200),
+      Row("2#david", 11300), Row("6#jen", 13200)))
+
+    val df3 = sql("select concat_ws('#', DEPT, NAME), MAX(SALARY) + MIN(BONUS)" +
+      " FROM h2.test.employee where dept > 0 group by concat_ws('#', DEPT, NAME)")
+    val filters3 = df3.queryExecution.optimizedPlan.collect {
+      case f: Filter => f
+    }
+    assert(filters3.isEmpty)
+    checkAggregateRemoved(df3, false)
+    df3.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expected_plan_fragment =
+          "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], "
+        checkKeywordsExistsInExplain(df3, expected_plan_fragment)
+    }
+    checkAnswer(df3, Seq(Row("1#amy", 11000), Row("1#cathy", 10200), Row("2#alex", 13200),
+      Row("2#david", 11300), Row("6#jen", 13200)))
+  }
+
   test("scan with aggregate push-down: with having clause") {
     val df = sql("select MAX(SALARY), MIN(BONUS) FROM h2.test.employee where dept > 0" +
       " group by DEPT having MIN(BONUS) > 1000")
@@ -541,6 +617,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       case f: Filter => f  // filter over aggregate not push down
     }
     assert(filters.nonEmpty)
+    checkAggregateRemoved(df)
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
@@ -556,6 +633,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     val df = sql("select * from h2.test.employee")
       .groupBy($"DEPT")
       .min("SALARY").as("total")
+    checkAggregateRemoved(df)
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
@@ -579,6 +657,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       case f: Filter => f
     }
     assert(filters.nonEmpty) // filter over aggregate not pushed down
+    checkAggregateRemoved(df)
     query.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
@@ -594,6 +673,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     val df = spark.table("h2.test.employee")
     val decrease = udf { (x: Double, y: Double) => x - y }
     val query = df.select(decrease(sum($"SALARY"), sum($"BONUS")).as("value"))
+    checkAggregateRemoved(query)
     query.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
@@ -607,6 +687,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     val cols = Seq("a", "b", "c", "d")
     val df1 = sql("select * from h2.test.employee").toDF(cols: _*)
     val df2 = df1.groupBy().sum("c")
+    checkAggregateRemoved(df2, false)
     df2.queryExecution.optimizedPlan.collect {
       case relation: DataSourceV2ScanRelation => relation.scan match {
         case v1: V1ScanWrapper =>
@@ -625,6 +706,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       .filter("SALARY > 100")
       .filter(name($"shortName"))
       .agg(sum($"SALARY").as("sum_salary"))
+    checkAggregateRemoved(query, false)
     query.queryExecution.optimizedPlan.collect {
       case relation: DataSourceV2ScanRelation => relation.scan match {
         case v1: V1ScanWrapper =>
@@ -633,4 +715,17 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     }
     checkAnswer(query, Seq(Row(29000.0)))
   }
+
+  test("scan with aggregate push-down: SUM(CASE WHEN) with group by") {
+    val df =
+      sql("SELECT SUM(CASE WHEN SALARY > 0 THEN 1 ELSE 0 END) FROM h2.test.employee GROUP BY DEPT")
+    checkAggregateRemoved(df, false)
+    df.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expected_plan_fragment =
+          "PushedFilters: [], "
+        checkKeywordsExistsInExplain(df, expected_plan_fragment)
+    }
+    checkAnswer(df, Seq(Row(1), Row(2), Row(2)))
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
