huaxingao commented on a change in pull request #29695:
URL: https://github.com/apache/spark/pull/29695#discussion_r493151432



##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
##########
@@ -264,6 +265,19 @@ private[sql] case class JDBCRelation(
     }
   }
 
+  override def unhandledAggregates(aggregates: Array[AggregateFunc]):
+    Array[AggregateFunc] = {
+    if (jdbcOptions.pushDownAggregate) {
+      if (JDBCRDD.compileAggregates(aggregates, JdbcDialects.get(jdbcOptions.url)).isEmpty) {
+        aggregates
+      } else {
+        Array.empty[AggregateFunc]
+      }
+    } else {
+      aggregates
+    }
+  }

Review comment:
       will remove this method

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
##########
@@ -17,38 +17,117 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
-import org.apache.spark.sql.catalyst.expressions.{And, Expression, NamedExpression, ProjectionOverSchema, SubqueryExpression}
+import scala.collection.mutable.ArrayBuilder
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.planning.ScanOperation
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.connector.read.{Scan, V1Scan}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.sources
+import org.apache.spark.sql.sources.{AggregateFunc, Aggregation}
 import org.apache.spark.sql.types.StructType
 
 object V2ScanRelationPushDown extends Rule[LogicalPlan] {
+
   import DataSourceV2Implicits._
 
   override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
+    case Aggregate(groupingExpressions, resultExpressions, child) =>
+      child match {
+        case ScanOperation(project, filters, relation: DataSourceV2Relation) =>
+          val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options)
+          val aggregates = resultExpressions.flatMap { expr =>
+            expr.collect {
+              case agg: AggregateExpression => agg
+            }
+          }.distinct
+
+          val aggregation = PushDownUtils.pushAggregates(scanBuilder, aggregates,
+            groupingExpressions)
+
+          val (pushedFilters, postScanFilters, scan, output, normalizedProjects) =
+            processFilerAndColumn(scanBuilder, project, filters, relation)
+
+          logInfo(
+            s"""
+               |Pushing operators to ${relation.name}
+               |Pushed Filters: ${pushedFilters.mkString(", ")}
+               |Post-Scan Filters: ${postScanFilters.mkString(",")}
+               |Pushed Aggregate Functions: ${aggregation.aggregateExpressions.mkString(", ")}
+               |Pushed Groupby: ${aggregation.groupByExpressions.mkString(", ")}
+               |Output: ${output.mkString(", ")}
+             """.stripMargin)
+
+          val wrappedScan = scan match {
+            case v1: V1Scan =>
+              val translated = filters.flatMap(DataSourceStrategy.translateFilter(_, true))
+              V1ScanWrapper(v1, translated, pushedFilters, aggregation)
+            case _ => scan
+          }
+
+          if (aggregation.aggregateExpressions.isEmpty) {
+            val plan = buildLogicalPlan(project, relation, wrappedScan, output, normalizedProjects,
+              postScanFilters)
+            Aggregate(groupingExpressions, resultExpressions, plan)
+          } else {
+            val resultAttributes = resultExpressions.map(_.toAttribute)
+              .map ( e => e match { case a: AttributeReference => a })
+            var index = 0
+            val aggOutputBuilder = ArrayBuilder.make[AttributeReference]
+            for (a <- resultAttributes) {
+              aggOutputBuilder +=
+                a.copy(dataType = aggregates(index).dataType)(exprId = NamedExpression.newExprId,
+                  qualifier = a.qualifier)
+              index += 1
+            }
+            val aggOutput = aggOutputBuilder.result
+
+            var newOutput = aggOutput
+            for (col <- output) {
+              if (!aggOutput.exists(_.name.contains(col.name))) {
+                newOutput = col +: newOutput
+              }
+            }
+
+            val r = buildLogicalPlan(newOutput, relation, wrappedScan, newOutput,
+              normalizedProjects, postScanFilters)

Review comment:
       Use the new output to build the plan.
   For example: the original output is `JDBCScan$$anon$1@226de93c [ID#18]`, and the new output is `JDBCScan$$anon$1@3f6f9cef [max(ID)#77,min(ID)#78]`.
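   As a minimal, hedged sketch (not this PR's exact code), the new output amounts to one attribute per pushed aggregate, named and typed after the aggregate result; `pushedAggregates` is an assumed parameter name:
   ```scala
   import org.apache.spark.sql.catalyst.expressions.AttributeReference
   import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression

   // One scan output column per pushed aggregate, named after the aggregate
   // and typed with its result type, e.g. max(ID)#77, min(ID)#78.
   def aggregateOutput(pushedAggregates: Seq[AggregateExpression]): Seq[AttributeReference] =
     pushedAggregates.map(agg => AttributeReference(agg.sql, agg.dataType)())
   ```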

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
##########
@@ -17,38 +17,117 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
-import org.apache.spark.sql.catalyst.expressions.{And, Expression, NamedExpression, ProjectionOverSchema, SubqueryExpression}
+import scala.collection.mutable.ArrayBuilder
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.planning.ScanOperation
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.connector.read.{Scan, V1Scan}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.sources
+import org.apache.spark.sql.sources.{AggregateFunc, Aggregation}
 import org.apache.spark.sql.types.StructType
 
 object V2ScanRelationPushDown extends Rule[LogicalPlan] {
+
   import DataSourceV2Implicits._
 
   override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
+    case Aggregate(groupingExpressions, resultExpressions, child) =>
+      child match {
+        case ScanOperation(project, filters, relation: DataSourceV2Relation) =>
+          val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options)
+          val aggregates = resultExpressions.flatMap { expr =>
+            expr.collect {
+              case agg: AggregateExpression => agg
+            }
+          }.distinct
+
+          val aggregation = PushDownUtils.pushAggregates(scanBuilder, aggregates,
+            groupingExpressions)
+
+          val (pushedFilters, postScanFilters, scan, output, normalizedProjects) =
+            processFilerAndColumn(scanBuilder, project, filters, relation)
+
+          logInfo(
+            s"""
+               |Pushing operators to ${relation.name}
+               |Pushed Filters: ${pushedFilters.mkString(", ")}
+               |Post-Scan Filters: ${postScanFilters.mkString(",")}
+               |Pushed Aggregate Functions: ${aggregation.aggregateExpressions.mkString(", ")}
+               |Pushed Groupby: ${aggregation.groupByExpressions.mkString(", ")}
+               |Output: ${output.mkString(", ")}
+             """.stripMargin)
+
+          val wrappedScan = scan match {
+            case v1: V1Scan =>
+              val translated = filters.flatMap(DataSourceStrategy.translateFilter(_, true))
+              V1ScanWrapper(v1, translated, pushedFilters, aggregation)
+            case _ => scan
+          }
+
+          if (aggregation.aggregateExpressions.isEmpty) {
+            val plan = buildLogicalPlan(project, relation, wrappedScan, output, normalizedProjects,
+              postScanFilters)
+            Aggregate(groupingExpressions, resultExpressions, plan)
+          } else {
+            val resultAttributes = resultExpressions.map(_.toAttribute)
+              .map ( e => e match { case a: AttributeReference => a })
+            var index = 0
+            val aggOutputBuilder = ArrayBuilder.make[AttributeReference]
+            for (a <- resultAttributes) {
+              aggOutputBuilder +=
+                a.copy(dataType = aggregates(index).dataType)(exprId = NamedExpression.newExprId,
+                  qualifier = a.qualifier)
+              index += 1
+            }
+            val aggOutput = aggOutputBuilder.result
+
+            var newOutput = aggOutput
+            for (col <- output) {
+              if (!aggOutput.exists(_.name.contains(col.name))) {
+                newOutput = col +: newOutput
+              }
+            }
+
+            val r = buildLogicalPlan(newOutput, relation, wrappedScan, newOutput,
+              normalizedProjects, postScanFilters)
+            val plan = Aggregate(groupingExpressions, resultExpressions, r)
+
+            var i = 0
+            plan.transformExpressions {
+            case agg: AggregateExpression =>
+              val aggFunction: aggregate.AggregateFunction = {
+                i += 1
+                if (agg.aggregateFunction.isInstanceOf[aggregate.Max]) {
+                  aggregate.Max(aggOutput(i - 1))
+                } else if (agg.aggregateFunction.isInstanceOf[aggregate.Min]) {
+                  aggregate.Min(aggOutput(i - 1))
+                } else if (agg.aggregateFunction.isInstanceOf[aggregate.Average]) {
+                  aggregate.Average(aggOutput(i - 1))
+                } else if (agg.aggregateFunction.isInstanceOf[aggregate.Sum]) {
+                  aggregate.Sum(aggOutput(i - 1))
+                } else {
+                  agg.aggregateFunction
+                }
+              }
+              agg.transform {
+                case a: aggregate.AggregateFunction => aggFunction
+              }

Review comment:
       Update the optimized logical plan.
   Before the update:
   ```
   Aggregate [max(id#18) AS max(id)#21, min(id#18) AS min(id)#22]
   +- RelationV2[ID#18] test.people
   ```
   After the update:
   ```
   Aggregate [max(max(ID)#77) AS max(ID)#72, min(min(ID)#78) AS min(ID)#73]
   +- RelationV2[max(ID)#77, min(ID)#78] test.people
   ```
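   A hedged sketch of the corresponding expression rewrite (illustration only, not the PR's exact code), assuming `aggOutput` holds the scan attributes created for the pushed aggregates such as `max(ID)#77` and `min(ID)#78`:
   ```scala
   import org.apache.spark.sql.catalyst.expressions.AttributeReference
   import org.apache.spark.sql.catalyst.expressions.aggregate._
   import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

   // Re-point each final aggregate at the column the data source has already
   // aggregated, e.g. max(id#18) becomes max(max(ID)#77).
   def rewriteAggregates(plan: LogicalPlan, aggOutput: Seq[AttributeReference]): LogicalPlan = {
     var i = -1
     plan.transformExpressions {
       case agg: AggregateExpression =>
         i += 1
         val func = agg.aggregateFunction match {
           case _: Max     => Max(aggOutput(i))
           case _: Min     => Min(aggOutput(i))
           case _: Sum     => Sum(aggOutput(i))
           case _: Average => Average(aggOutput(i))
           case other      => other
         }
         agg.copy(aggregateFunction = func)
     }
   }
   ```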

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
##########
@@ -232,6 +232,17 @@ abstract class BaseRelation {
    * @since 1.6.0
    */
   def unhandledFilters(filters: Array[Filter]): Array[Filter] = filters
+
+  /**
+   * Returns the list of [[Aggregate]]s that this datasource may not be able to handle.
+   * These returned [[Aggregate]]s will be evaluated by Spark SQL after data is output by a scan.
+   * By default, this function will return all aggregates, as it is always safe to
+   * double evaluate an [[Aggregate]].
+   *
+   * @since 3.1.0
+   */
+  def unhandledAggregates(aggregates: Array[AggregateFunc]): Array[AggregateFunc] =
+    aggregates

Review comment:
       This is not used. Will remove
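   For reference, a hedged illustration of what an override of this hook could have looked like (sketch only; `canPushDown` is a hypothetical helper, not part of this PR), returning the aggregates that Spark itself must still evaluate:
   ```scala
   import org.apache.spark.sql.sources.AggregateFunc

   // Hypothetical capability check for the underlying source; "handle nothing"
   // is the conservative default and matches the base implementation's behavior.
   def canPushDown(agg: AggregateFunc): Boolean = false

   // Same contract as the default: whatever is returned here is re-evaluated by Spark.
   def unhandledAggregates(aggregates: Array[AggregateFunc]): Array[AggregateFunc] =
     aggregates.filterNot(canPushDown)
   ```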



