sunchao commented on a change in pull request #29695:
URL: https://github.com/apache/spark/pull/29695#discussion_r606460240



##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
##########
@@ -700,6 +704,49 @@ object DataSourceStrategy
     (nonconvertiblePredicates ++ unhandledPredicates, pushedFilters, 
handledFilters)
   }
 
+  private def columnAsString(e: Expression): String = e match {
+    case AttributeReference(name, _, _, _) => name
+    case Cast(child, _, _) => columnAsString (child)
+    case Add(left, right, _) =>
+      columnAsString(left) + " + " + columnAsString(right)
+    case Subtract(left, right, _) =>
+      columnAsString(left) + " - " + columnAsString(right)
+    case Multiply(left, right, _) =>
+      columnAsString(left) + " * " + columnAsString(right)
+    case Divide(left, right, _) =>
+      columnAsString(left) + " / " + columnAsString(right)
+    case CheckOverflow(child, _, _) => columnAsString (child)

Review comment:
       nit: extra space after `columnAsString`.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
##########
@@ -17,38 +17,132 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
-import org.apache.spark.sql.catalyst.expressions.{And, Expression, 
NamedExpression, ProjectionOverSchema, SubqueryExpression}
+import scala.collection.mutable.ArrayBuilder
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.planning.ScanOperation
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, 
Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, 
LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.connector.read.{Scan, V1Scan}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.sources
+import org.apache.spark.sql.sources.{AggregateFunc, Aggregation}
 import org.apache.spark.sql.types.StructType
 
-object V2ScanRelationPushDown extends Rule[LogicalPlan] {
+object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper {
+
   import DataSourceV2Implicits._
 
   override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
-    case ScanOperation(project, filters, relation: DataSourceV2Relation) =>
-      val scanBuilder = 
relation.table.asReadable.newScanBuilder(relation.options)
+    case Aggregate(groupingExpressions, resultExpressions, child) =>

Review comment:
       This is a little hard to read. Maybe we can better separate the logic for 
pushing down aggregates from the logic for pushing down filters. Some comments 
would also help; see the sketch below.
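
   A rough sketch of the shape I have in mind (the `pushDownAggregate` and 
   `pushDownFilterAndColumns` helpers below are hypothetical names, just to 
   illustrate the split; `pushDownFilter`/`processFilterAndColumn` already exist 
   in this PR and could back them):
   ```scala
   override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
     // Aggregate directly on top of a DSv2 scan: handle aggregate pushdown in
     // its own helper so this match stays readable.
     case agg @ Aggregate(_, _, ScanOperation(_, _, _: DataSourceV2Relation)) =>
       pushDownAggregate(agg)

     // Plain scan with filters/projection: keep the existing filter/column
     // pushdown path untouched.
     case ScanOperation(project, filters, relation: DataSourceV2Relation) =>
       pushDownFilterAndColumns(project, filters, relation)
   }
   ```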

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
##########
@@ -17,38 +17,132 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
-import org.apache.spark.sql.catalyst.expressions.{And, Expression, 
NamedExpression, ProjectionOverSchema, SubqueryExpression}
+import scala.collection.mutable.ArrayBuilder
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.planning.ScanOperation
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, 
Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, 
LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.connector.read.{Scan, V1Scan}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.sources
+import org.apache.spark.sql.sources.{AggregateFunc, Aggregation}
 import org.apache.spark.sql.types.StructType
 
-object V2ScanRelationPushDown extends Rule[LogicalPlan] {
+object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper {
+
   import DataSourceV2Implicits._
 
   override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
-    case ScanOperation(project, filters, relation: DataSourceV2Relation) =>
-      val scanBuilder = 
relation.table.asReadable.newScanBuilder(relation.options)
+    case Aggregate(groupingExpressions, resultExpressions, child) =>
+      child match {
+        case ScanOperation(project, filters, relation: DataSourceV2Relation) =>
+          val scanBuilder = 
relation.table.asReadable.newScanBuilder(relation.options)
 
-      val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, 
relation.output)
-      val (normalizedFiltersWithSubquery, normalizedFiltersWithoutSubquery) =
-        normalizedFilters.partition(SubqueryExpression.hasSubquery)
+          val aliasMap = getAliasMap(project)
+          var aggregates = resultExpressions.flatMap { expr =>
+            expr.collect {
+              case agg: AggregateExpression =>
+                replaceAlias(agg, aliasMap).asInstanceOf[AggregateExpression]
+            }
+          }
+          aggregates = DataSourceStrategy.normalizeExprs(aggregates, 
relation.output)
+            .asInstanceOf[Seq[AggregateExpression]]
 
-      // `pushedFilters` will be pushed down and evaluated in the underlying 
data sources.
-      // `postScanFilters` need to be evaluated after the scan.
-      // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet 
row group filter.
-      val (pushedFilters, postScanFiltersWithoutSubquery) = 
PushDownUtils.pushFilters(
-        scanBuilder, normalizedFiltersWithoutSubquery)
-      val postScanFilters = postScanFiltersWithoutSubquery ++ 
normalizedFiltersWithSubquery
+          val groupingExpressionsWithoutAlias = groupingExpressions.flatMap{ 
expr =>
+            expr.collect {
+              case e: Expression => replaceAlias(e, aliasMap)
+            }
+          }
+          val normalizedGroupingExpressions =
+            DataSourceStrategy.normalizeExprs(groupingExpressionsWithoutAlias, 
relation.output)
+
+          var newFilters = filters
+          aggregates.foreach(agg =>
+            if (agg.filter.nonEmpty)  {
+              // handle agg filter the same way as other filters
+              newFilters = newFilters :+ agg.filter.get
+            }
+          )
+
+          val (pushedFilters, postScanFilters) = pushDownFilter(scanBuilder, 
newFilters, relation)
+          if (postScanFilters.nonEmpty) {
+            Aggregate(groupingExpressions, resultExpressions, child)

Review comment:
       Perhaps we should return the original plan node here rather than 
constructing a new `Aggregate`?
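
   For illustration, a minimal self-contained shape of the idea (the 
   `shouldKeepOriginal` predicate is hypothetical, standing in for the 
   `postScanFilters.nonEmpty` check):
   ```scala
   import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}

   // Hypothetical stand-in for the "can't push everything down" condition.
   def shouldKeepOriginal(child: LogicalPlan): Boolean = true

   def rewrite(plan: LogicalPlan): LogicalPlan = plan transformDown {
     // Bind the matched node with `agg @ ...` and return it untouched when we
     // decide not to push down, instead of building an equal new Aggregate.
     case agg @ Aggregate(_, _, child) if shouldKeepOriginal(child) => agg
   }
   ```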

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
##########
@@ -133,6 +135,68 @@ object JDBCRDD extends Logging {
     })
   }
 
+  private def containsArithmeticOp(col: String): Boolean =
+    col.contains("+") || col.contains("-") || col.contains("*") || 
col.contains("/")
+
+  def compileAggregates(
+      aggregates: Seq[AggregateFunc],
+      dialect: JdbcDialect): (Array[String], Array[DataType]) = {
+    def quote(colName: String): String = dialect.quoteIdentifier(colName)
+    val aggBuilder = ArrayBuilder.make[String]
+    val dataTypeBuilder = ArrayBuilder.make[DataType]
+    aggregates.map {
+      case Min(column, dataType) =>
+        dataTypeBuilder += dataType
+        if (!containsArithmeticOp(column)) {
+          aggBuilder += s"MIN(${quote(column)})"
+        } else {
+          aggBuilder += s"MIN(${quoteEachCols(column, dialect)})"
+        }
+      case Max(column, dataType) =>
+        dataTypeBuilder += dataType
+        if (!containsArithmeticOp(column)) {
+          aggBuilder += s"MAX(${quote(column)})"
+        } else {
+          aggBuilder += s"MAX(${quoteEachCols(column, dialect)})"
+        }
+      case Sum(column, dataType, isDistinct) =>
+        val distinct = if (isDistinct) "DISTINCT " else ""
+        dataTypeBuilder += dataType
+        if (!containsArithmeticOp(column)) {
+          aggBuilder += s"SUM(${distinct} ${quote(column)})"
+        } else {
+          aggBuilder += s"SUM(${distinct} ${quoteEachCols(column, dialect)})"
+        }
+      case Avg(column, dataType, isDistinct) =>
+        val distinct = if (isDistinct) "DISTINCT " else ""
+        dataTypeBuilder += dataType
+        if (!containsArithmeticOp(column)) {
+          aggBuilder += s"AVG(${distinct} ${quote(column)})"
+        } else {
+          aggBuilder += s"AVG(${distinct} ${quoteEachCols(column, dialect)})"
+        }
+      case Count(column, dataType, isDistinct) =>
+        val distinct = if (isDistinct) "DISTINCT " else ""
+        dataTypeBuilder += dataType
+        val col = if (column.equals("1")) column else quote(column)
+          aggBuilder += s"COUNT(${distinct} $col)"
+      case _ =>
+    }
+    (aggBuilder.result, dataTypeBuilder.result)
+  }
+
+  private def quoteEachCols (column: String, dialect: JdbcDialect): String = {

Review comment:
       nit: extra space before `(`.

##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/sources/aggregates.scala
##########
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources
+
+import org.apache.spark.sql.types.DataType
+
+case class Aggregation(aggregateExpressions: Seq[AggregateFunc],

Review comment:
       I think these need some docs since they are user-facing, and maybe some 
examples of how to handle `aggregateExpressions` and `groupByExpressions`. For 
the latter, should we also rename it to `groupByColumns`? A rough sketch is below.
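
   For example, something along these lines (just a sketch; the exact field 
   names and types should follow whatever this PR settles on, and I used the 
   suggested `groupByColumns` name here):
   ```scala
   /**
    * Represents an aggregation pushed down to the data source: the aggregate
    * functions to evaluate at the source plus the columns to group by.
    *
    * For example, `SELECT MIN(price), MAX(price) FROM t GROUP BY dept` could be
    * pushed down as
    * `Aggregation(Seq(Min("price", ...), Max("price", ...)), Seq("dept"))`.
    *
    * @param aggregateExpressions the aggregate functions (MIN/MAX/SUM/AVG/COUNT)
    *                             the data source should evaluate
    * @param groupByColumns the columns to group by; empty means a global aggregate
    */
   case class Aggregation(
       aggregateExpressions: Seq[AggregateFunc],
       groupByColumns: Seq[String])
   ```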

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
##########
@@ -17,38 +17,132 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
-import org.apache.spark.sql.catalyst.expressions.{And, Expression, 
NamedExpression, ProjectionOverSchema, SubqueryExpression}
+import scala.collection.mutable.ArrayBuilder
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.planning.ScanOperation
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, 
Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, 
LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.connector.read.{Scan, V1Scan}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.sources
+import org.apache.spark.sql.sources.{AggregateFunc, Aggregation}
 import org.apache.spark.sql.types.StructType
 
-object V2ScanRelationPushDown extends Rule[LogicalPlan] {
+object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper {
+
   import DataSourceV2Implicits._
 
   override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
-    case ScanOperation(project, filters, relation: DataSourceV2Relation) =>
-      val scanBuilder = 
relation.table.asReadable.newScanBuilder(relation.options)
+    case Aggregate(groupingExpressions, resultExpressions, child) =>
+      child match {
+        case ScanOperation(project, filters, relation: DataSourceV2Relation) =>
+          val scanBuilder = 
relation.table.asReadable.newScanBuilder(relation.options)
 
-      val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, 
relation.output)
-      val (normalizedFiltersWithSubquery, normalizedFiltersWithoutSubquery) =
-        normalizedFilters.partition(SubqueryExpression.hasSubquery)
+          val aliasMap = getAliasMap(project)
+          var aggregates = resultExpressions.flatMap { expr =>
+            expr.collect {
+              case agg: AggregateExpression =>
+                replaceAlias(agg, aliasMap).asInstanceOf[AggregateExpression]
+            }
+          }
+          aggregates = DataSourceStrategy.normalizeExprs(aggregates, 
relation.output)
+            .asInstanceOf[Seq[AggregateExpression]]
 
-      // `pushedFilters` will be pushed down and evaluated in the underlying 
data sources.
-      // `postScanFilters` need to be evaluated after the scan.
-      // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet 
row group filter.
-      val (pushedFilters, postScanFiltersWithoutSubquery) = 
PushDownUtils.pushFilters(
-        scanBuilder, normalizedFiltersWithoutSubquery)
-      val postScanFilters = postScanFiltersWithoutSubquery ++ 
normalizedFiltersWithSubquery
+          val groupingExpressionsWithoutAlias = groupingExpressions.flatMap{ 
expr =>
+            expr.collect {
+              case e: Expression => replaceAlias(e, aliasMap)
+            }
+          }
+          val normalizedGroupingExpressions =
+            DataSourceStrategy.normalizeExprs(groupingExpressionsWithoutAlias, 
relation.output)
+
+          var newFilters = filters
+          aggregates.foreach(agg =>
+            if (agg.filter.nonEmpty)  {
+              // handle agg filter the same way as other filters
+              newFilters = newFilters :+ agg.filter.get
+            }
+          )
+
+          val (pushedFilters, postScanFilters) = pushDownFilter(scanBuilder, 
newFilters, relation)
+          if (postScanFilters.nonEmpty) {
+            Aggregate(groupingExpressions, resultExpressions, child)
+          } else { // only push down aggregate if all the filers can be push 
down
+            val aggregation = PushDownUtils.pushAggregates(scanBuilder, 
aggregates,
+              normalizedGroupingExpressions)
+
+            val (scan, output, normalizedProjects) =
+              processFilterAndColumn(scanBuilder, project, postScanFilters, 
relation)
+
+            logInfo(
+              s"""
+                 |Pushing operators to ${relation.name}
+                 |Pushed Filters: ${pushedFilters.mkString(", ")}
+                 |Post-Scan Filters: ${postScanFilters.mkString(",")}
+                 |Pushed Aggregate Functions: 
${aggregation.aggregateExpressions.mkString(", ")}
+                 |Pushed Groupby: ${aggregation.groupByExpressions.mkString(", 
")}
+                 |Output: ${output.mkString(", ")}
+             """.stripMargin)
+
+            val wrappedScan = scan match {
+              case v1: V1Scan =>
+                val translated = 
newFilters.flatMap(DataSourceStrategy.translateFilter(_, true))
+                V1ScanWrapper(v1, translated, pushedFilters, aggregation)
+              case _ => scan
+            }
+
+            if (aggregation.aggregateExpressions.isEmpty) {
+              Aggregate(groupingExpressions, resultExpressions, child)

Review comment:
       ditto

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
##########
@@ -700,6 +704,49 @@ object DataSourceStrategy
     (nonconvertiblePredicates ++ unhandledPredicates, pushedFilters, 
handledFilters)
   }
 
+  private def columnAsString(e: Expression): String = e match {
+    case AttributeReference(name, _, _, _) => name
+    case Cast(child, _, _) => columnAsString (child)

Review comment:
       nit: extra space after `columnAsString`.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
##########
@@ -700,6 +704,49 @@ object DataSourceStrategy
     (nonconvertiblePredicates ++ unhandledPredicates, pushedFilters, 
handledFilters)
   }
 
+  private def columnAsString(e: Expression): String = e match {
+    case AttributeReference(name, _, _, _) => name
+    case Cast(child, _, _) => columnAsString (child)
+    case Add(left, right, _) =>
+      columnAsString(left) + " + " + columnAsString(right)
+    case Subtract(left, right, _) =>
+      columnAsString(left) + " - " + columnAsString(right)
+    case Multiply(left, right, _) =>
+      columnAsString(left) + " * " + columnAsString(right)
+    case Divide(left, right, _) =>
+      columnAsString(left) + " / " + columnAsString(right)
+    case CheckOverflow(child, _, _) => columnAsString (child)
+    case PromotePrecision(child) => columnAsString (child)
+    case _ => ""
+  }
+
+  protected[sql] def translateAggregate(aggregates: AggregateExpression): 
Option[AggregateFunc] = {
+    aggregates.aggregateFunction match {
+      case min: aggregate.Min =>
+        val colName = columnAsString(min.child)
+        if (colName.nonEmpty) Some(Min(colName, min.dataType)) else None
+      case max: aggregate.Max =>
+        val colName = columnAsString(max.child)
+        if (colName.nonEmpty) Some(Max(colName, max.dataType)) else None
+      case avg: aggregate.Average =>
+        val colName = columnAsString(avg.child)
+        if (colName.nonEmpty) Some(Avg(colName, avg.dataType, 
aggregates.isDistinct)) else None
+      case sum: aggregate.Sum =>
+        val colName = columnAsString(sum.child)
+        if (colName.nonEmpty) Some(Sum(colName, sum.dataType, 
aggregates.isDistinct)) else None
+      case count: aggregate.Count =>
+        val columnName = count.children.head match {
+          case Literal(_, _) => "1"

Review comment:
       Why is this `"1"`? Also, should we check whether there is more than one 
element in `children`? A sketch of what I mean is below.
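
   If the intent is that `COUNT(<literal>)` behaves like `COUNT(1)`/`COUNT(*)`, 
   a comment plus an explicit arity check would make that clearer. A sketch of 
   the idea (a hypothetical helper, reusing `columnAsString` from this file):
   ```scala
   // Translate Count's children to a column string: a single literal child is
   // treated as COUNT(1), i.e. count every row; anything with more than one
   // child is not supported here, so return None and skip the pushdown.
   private def countColumn(children: Seq[Expression]): Option[String] = children match {
     case Seq(Literal(_, _)) => Some("1")
     case Seq(child) => Some(columnAsString(child)).filter(_.nonEmpty)
     case _ => None
   }
   ```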

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
##########
@@ -700,6 +704,41 @@ object DataSourceStrategy
     (nonconvertiblePredicates ++ unhandledPredicates, pushedFilters, 
handledFilters)
   }
 
+  private def columnAsString(e: Expression): String = e match {

Review comment:
       +1. It also seems strange to convert a binary expression into a "magic" 
string form that is (or at least seems to be) specific to JDBC data sources.
   
   I also wonder if we should handle nested columns the same way as 
`PushableColumnBase` does; see the sketch below.
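
   Roughly what I have in mind (a simplified sketch of the PushableColumnBase 
   style of traversal; I'm writing `extractFieldName` from memory, so treat that 
   accessor name as an assumption):
   ```scala
   // Walk GetStructField chains down to the root attribute and join the parts
   // with dots, returning None for anything else so the caller can skip pushdown.
   private def nestedColumnAsString(e: Expression): Option[String] = e match {
     case a: Attribute => Some(a.name)
     case g: GetStructField =>
       nestedColumnAsString(g.child).map(parent => s"$parent.${g.extractFieldName}")
     case _ => None
   }
   ```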

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
##########
@@ -70,6 +72,43 @@ object PushDownUtils extends PredicateHelper {
     }
   }
 
+    /**
+     * Pushes down aggregates to the data source reader
+     *
+     * @return pushed aggregation.
+     */
+    def pushAggregates(
+        scanBuilder: ScanBuilder,
+        aggregates: Seq[AggregateExpression],
+        groupBy: Seq[Expression]): Aggregation = {
+
+      def columnAsString(e: Expression): String = e match {
+        case AttributeReference(name, _, _, _) => name
+        case _ => ""
+      }
+
+      scanBuilder match {
+        case r: SupportsPushDownAggregates =>
+          val translatedAggregates = 
mutable.ArrayBuffer.empty[sources.AggregateFunc]
+
+          for (aggregateExpr <- aggregates) {
+            val translated = 
DataSourceStrategy.translateAggregate(aggregateExpr)
+            if (translated.isEmpty) {
+              return Aggregation.empty
+            } else {
+              translatedAggregates += translated.get
+            }
+          }
+          val groupByCols = groupBy.map(columnAsString(_))
+          if (!groupByCols.exists(_.isEmpty)) {
+            r.pushAggregation(Aggregation(translatedAggregates, groupByCols))
+          }

Review comment:
       What about the `else` branch? Perhaps we should revise this code to:
   ```scala
           case r: SupportsPushDownAggregates =>
             val translatedAggregates = aggregates.map(DataSourceStrategy.translateAggregate)
             val translatedGroupBys = groupBy.map(columnAsString)

             if (translatedAggregates.exists(_.isEmpty) || translatedGroupBys.exists(_.isEmpty)) {
               Aggregation.empty
             } else {
               r.pushAggregation(Aggregation(translatedAggregates.flatten, translatedGroupBys))
               r.pushedAggregation
             }
   ```

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
##########
@@ -17,38 +17,132 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
-import org.apache.spark.sql.catalyst.expressions.{And, Expression, 
NamedExpression, ProjectionOverSchema, SubqueryExpression}
+import scala.collection.mutable.ArrayBuilder
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.planning.ScanOperation
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, 
Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, 
LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.connector.read.{Scan, V1Scan}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.sources
+import org.apache.spark.sql.sources.{AggregateFunc, Aggregation}
 import org.apache.spark.sql.types.StructType
 
-object V2ScanRelationPushDown extends Rule[LogicalPlan] {
+object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper {
+
   import DataSourceV2Implicits._
 
   override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
-    case ScanOperation(project, filters, relation: DataSourceV2Relation) =>
-      val scanBuilder = 
relation.table.asReadable.newScanBuilder(relation.options)
+    case Aggregate(groupingExpressions, resultExpressions, child) =>
+      child match {
+        case ScanOperation(project, filters, relation: DataSourceV2Relation) =>
+          val scanBuilder = 
relation.table.asReadable.newScanBuilder(relation.options)
 
-      val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, 
relation.output)
-      val (normalizedFiltersWithSubquery, normalizedFiltersWithoutSubquery) =
-        normalizedFilters.partition(SubqueryExpression.hasSubquery)
+          val aliasMap = getAliasMap(project)
+          var aggregates = resultExpressions.flatMap { expr =>
+            expr.collect {
+              case agg: AggregateExpression =>
+                replaceAlias(agg, aliasMap).asInstanceOf[AggregateExpression]
+            }
+          }
+          aggregates = DataSourceStrategy.normalizeExprs(aggregates, 
relation.output)
+            .asInstanceOf[Seq[AggregateExpression]]
 
-      // `pushedFilters` will be pushed down and evaluated in the underlying 
data sources.
-      // `postScanFilters` need to be evaluated after the scan.
-      // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet 
row group filter.
-      val (pushedFilters, postScanFiltersWithoutSubquery) = 
PushDownUtils.pushFilters(
-        scanBuilder, normalizedFiltersWithoutSubquery)
-      val postScanFilters = postScanFiltersWithoutSubquery ++ 
normalizedFiltersWithSubquery
+          val groupingExpressionsWithoutAlias = groupingExpressions.flatMap{ 
expr =>
+            expr.collect {
+              case e: Expression => replaceAlias(e, aliasMap)
+            }
+          }
+          val normalizedGroupingExpressions =
+            DataSourceStrategy.normalizeExprs(groupingExpressionsWithoutAlias, 
relation.output)
+
+          var newFilters = filters
+          aggregates.foreach(agg =>
+            if (agg.filter.nonEmpty)  {
+              // handle agg filter the same way as other filters
+              newFilters = newFilters :+ agg.filter.get
+            }
+          )
+
+          val (pushedFilters, postScanFilters) = pushDownFilter(scanBuilder, 
newFilters, relation)
+          if (postScanFilters.nonEmpty) {
+            Aggregate(groupingExpressions, resultExpressions, child)
+          } else { // only push down aggregate if all the filers can be push 
down
+            val aggregation = PushDownUtils.pushAggregates(scanBuilder, 
aggregates,
+              normalizedGroupingExpressions)
+
+            val (scan, output, normalizedProjects) =
+              processFilterAndColumn(scanBuilder, project, postScanFilters, 
relation)
+
+            logInfo(
+              s"""
+                 |Pushing operators to ${relation.name}
+                 |Pushed Filters: ${pushedFilters.mkString(", ")}
+                 |Post-Scan Filters: ${postScanFilters.mkString(",")}
+                 |Pushed Aggregate Functions: 
${aggregation.aggregateExpressions.mkString(", ")}
+                 |Pushed Groupby: ${aggregation.groupByExpressions.mkString(", 
")}
+                 |Output: ${output.mkString(", ")}
+             """.stripMargin)
+
+            val wrappedScan = scan match {
+              case v1: V1Scan =>
+                val translated = 
newFilters.flatMap(DataSourceStrategy.translateFilter(_, true))
+                V1ScanWrapper(v1, translated, pushedFilters, aggregation)
+              case _ => scan
+            }
+
+            if (aggregation.aggregateExpressions.isEmpty) {
+              Aggregate(groupingExpressions, resultExpressions, child)
+            } else {
+              val aggOutputBuilder = ArrayBuilder.make[AttributeReference]
+              for (i <- 0 until aggregates.length) {
+                aggOutputBuilder += AttributeReference(
+                  aggregation.aggregateExpressions(i).toString, 
aggregates(i).dataType)()

Review comment:
       Hmm, is this correct?

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
##########
@@ -17,38 +17,132 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
-import org.apache.spark.sql.catalyst.expressions.{And, Expression, 
NamedExpression, ProjectionOverSchema, SubqueryExpression}
+import scala.collection.mutable.ArrayBuilder
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.planning.ScanOperation
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, 
Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, 
LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.connector.read.{Scan, V1Scan}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.sources
+import org.apache.spark.sql.sources.{AggregateFunc, Aggregation}
 import org.apache.spark.sql.types.StructType
 
-object V2ScanRelationPushDown extends Rule[LogicalPlan] {
+object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper {
+
   import DataSourceV2Implicits._
 
   override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
-    case ScanOperation(project, filters, relation: DataSourceV2Relation) =>
-      val scanBuilder = 
relation.table.asReadable.newScanBuilder(relation.options)
+    case Aggregate(groupingExpressions, resultExpressions, child) =>
+      child match {
+        case ScanOperation(project, filters, relation: DataSourceV2Relation) =>
+          val scanBuilder = 
relation.table.asReadable.newScanBuilder(relation.options)
 
-      val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, 
relation.output)
-      val (normalizedFiltersWithSubquery, normalizedFiltersWithoutSubquery) =
-        normalizedFilters.partition(SubqueryExpression.hasSubquery)
+          val aliasMap = getAliasMap(project)
+          var aggregates = resultExpressions.flatMap { expr =>
+            expr.collect {
+              case agg: AggregateExpression =>
+                replaceAlias(agg, aliasMap).asInstanceOf[AggregateExpression]
+            }
+          }
+          aggregates = DataSourceStrategy.normalizeExprs(aggregates, 
relation.output)
+            .asInstanceOf[Seq[AggregateExpression]]
 
-      // `pushedFilters` will be pushed down and evaluated in the underlying 
data sources.
-      // `postScanFilters` need to be evaluated after the scan.
-      // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet 
row group filter.
-      val (pushedFilters, postScanFiltersWithoutSubquery) = 
PushDownUtils.pushFilters(
-        scanBuilder, normalizedFiltersWithoutSubquery)
-      val postScanFilters = postScanFiltersWithoutSubquery ++ 
normalizedFiltersWithSubquery
+          val groupingExpressionsWithoutAlias = groupingExpressions.flatMap{ 
expr =>
+            expr.collect {
+              case e: Expression => replaceAlias(e, aliasMap)
+            }
+          }
+          val normalizedGroupingExpressions =
+            DataSourceStrategy.normalizeExprs(groupingExpressionsWithoutAlias, 
relation.output)
+
+          var newFilters = filters
+          aggregates.foreach(agg =>
+            if (agg.filter.nonEmpty)  {
+              // handle agg filter the same way as other filters
+              newFilters = newFilters :+ agg.filter.get
+            }
+          )
+
+          val (pushedFilters, postScanFilters) = pushDownFilter(scanBuilder, 
newFilters, relation)
+          if (postScanFilters.nonEmpty) {
+            Aggregate(groupingExpressions, resultExpressions, child)
+          } else { // only push down aggregate if all the filers can be push 
down
+            val aggregation = PushDownUtils.pushAggregates(scanBuilder, 
aggregates,
+              normalizedGroupingExpressions)
+
+            val (scan, output, normalizedProjects) =
+              processFilterAndColumn(scanBuilder, project, postScanFilters, 
relation)
+
+            logInfo(
+              s"""
+                 |Pushing operators to ${relation.name}
+                 |Pushed Filters: ${pushedFilters.mkString(", ")}
+                 |Post-Scan Filters: ${postScanFilters.mkString(",")}
+                 |Pushed Aggregate Functions: 
${aggregation.aggregateExpressions.mkString(", ")}
+                 |Pushed Groupby: ${aggregation.groupByExpressions.mkString(", 
")}
+                 |Output: ${output.mkString(", ")}
+             """.stripMargin)
+
+            val wrappedScan = scan match {
+              case v1: V1Scan =>
+                val translated = 
newFilters.flatMap(DataSourceStrategy.translateFilter(_, true))
+                V1ScanWrapper(v1, translated, pushedFilters, aggregation)
+              case _ => scan
+            }
+
+            if (aggregation.aggregateExpressions.isEmpty) {
+              Aggregate(groupingExpressions, resultExpressions, child)
+            } else {
+              val aggOutputBuilder = ArrayBuilder.make[AttributeReference]
+              for (i <- 0 until aggregates.length) {
+                aggOutputBuilder += AttributeReference(
+                  aggregation.aggregateExpressions(i).toString, 
aggregates(i).dataType)()
+              }
+              groupingExpressions.foreach{
+                case a@AttributeReference(_, _, _, _) => aggOutputBuilder += a
+                case _ =>
+              }
+              val aggOutput = aggOutputBuilder.result
+
+              val r = buildLogicalPlan(aggOutput, relation, wrappedScan, 
aggOutput,
+                normalizedProjects, postScanFilters)
+              val plan = Aggregate(groupingExpressions, resultExpressions, r)
+
+              var i = 0
+              plan.transformExpressions {
+                case agg: AggregateExpression =>
+                  i += 1
+                  val aggFunction: aggregate.AggregateFunction = {
+                    if (agg.aggregateFunction.isInstanceOf[aggregate.Max]) {

Review comment:
       Could this use pattern matching instead of the `isInstanceOf` checks? A 
sketch is below.
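
   A self-contained shape of what I mean (the bodies below just rebuild the 
   function over a new child, purely as a placeholder for whatever this branch 
   actually needs to do):
   ```scala
   import org.apache.spark.sql.catalyst.expressions.Expression
   import org.apache.spark.sql.catalyst.expressions.aggregate

   // One match instead of a chain of isInstanceOf checks and casts.
   def rewriteAggFunction(
       f: aggregate.AggregateFunction,
       newChild: Expression): aggregate.AggregateFunction = f match {
     case max: aggregate.Max => max.copy(child = newChild)
     case min: aggregate.Min => min.copy(child = newChild)
     case other => other
   }
   ```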

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
##########
@@ -273,6 +273,16 @@ trait PrunedFilteredScan {
   def buildScan(requiredColumns: Array[String], filters: Array[Filter]): 
RDD[Row]
 }
 
+/**
+ * @since 3.1.0
+ */
+trait PrunedFilteredAggregateScan {

Review comment:
       It's a bit strange that this is a DSv1 API but is only used by the DSv2 
JDBC scan. Is it possible for a V1 data source to implement this and go through 
the V1 code path (i.e., through `DataSourceStrategy`)?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
