huaxingao commented on a change in pull request #29695:
URL: https://github.com/apache/spark/pull/29695#discussion_r493151656
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
##########
@@ -17,38 +17,117 @@
package org.apache.spark.sql.execution.datasources.v2
-import org.apache.spark.sql.catalyst.expressions.{And, Expression,
NamedExpression, ProjectionOverSchema, SubqueryExpression}
+import scala.collection.mutable.ArrayBuilder
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.planning.ScanOperation
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan,
Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter,
LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.connector.read.{Scan, V1Scan}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.sources
+import org.apache.spark.sql.sources.{AggregateFunc, Aggregation}
import org.apache.spark.sql.types.StructType
object V2ScanRelationPushDown extends Rule[LogicalPlan] {
+
import DataSourceV2Implicits._
override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
+ case Aggregate(groupingExpressions, resultExpressions, child) =>
+ child match {
+ case ScanOperation(project, filters, relation: DataSourceV2Relation) =>
+ val scanBuilder =
relation.table.asReadable.newScanBuilder(relation.options)
+ val aggregates = resultExpressions.flatMap { expr =>
+ expr.collect {
+ case agg: AggregateExpression => agg
+ }
+ }.distinct
+
+ val aggregation = PushDownUtils.pushAggregates(scanBuilder,
aggregates,
+ groupingExpressions)
+
+ val (pushedFilters, postScanFilters, scan, output,
normalizedProjects) =
+ processFilerAndColumn(scanBuilder, project, filters, relation)
+
+ logInfo(
+ s"""
+ |Pushing operators to ${relation.name}
+ |Pushed Filters: ${pushedFilters.mkString(", ")}
+ |Post-Scan Filters: ${postScanFilters.mkString(",")}
+ |Pushed Aggregate Functions:
${aggregation.aggregateExpressions.mkString(", ")}
+ |Pushed Groupby: ${aggregation.groupByExpressions.mkString(",
")}
+ |Output: ${output.mkString(", ")}
+ """.stripMargin)
+
+ val wrappedScan = scan match {
+ case v1: V1Scan =>
+ val translated =
filters.flatMap(DataSourceStrategy.translateFilter(_, true))
+ V1ScanWrapper(v1, translated, pushedFilters, aggregation)
+ case _ => scan
+ }
+
+ if (aggregation.aggregateExpressions.isEmpty) {
+ val plan = buildLogicalPlan(project, relation, wrappedScan,
output, normalizedProjects,
+ postScanFilters)
+ Aggregate(groupingExpressions, resultExpressions, plan)
+ } else {
+ val resultAttributes = resultExpressions.map(_.toAttribute)
+ .map ( e => e match { case a: AttributeReference => a })
+ var index = 0
+ val aggOutputBuilder = ArrayBuilder.make[AttributeReference]
+ for (a <- resultAttributes) {
+ aggOutputBuilder +=
+ a.copy(dataType = aggregates(index).dataType)(exprId =
NamedExpression.newExprId,
+ qualifier = a.qualifier)
+ index += 1
+ }
+ val aggOutput = aggOutputBuilder.result
+
+ var newOutput = aggOutput
+ for (col <- output) {
+ if (!aggOutput.exists(_.name.contains(col.name))) {
+ newOutput = col +: newOutput
+ }
+ }
+
+ val r = buildLogicalPlan(newOutput, relation, wrappedScan,
newOutput,
+ normalizedProjects, postScanFilters)
+ val plan = Aggregate(groupingExpressions, resultExpressions, r)
+
+ var i = 0
+ plan.transformExpressions {
+ case agg: AggregateExpression =>
+ val aggFunction: aggregate.AggregateFunction = {
+ i += 1
+ if (agg.aggregateFunction.isInstanceOf[aggregate.Max]) {
+ aggregate.Max(aggOutput(i - 1))
+ } else if (agg.aggregateFunction.isInstanceOf[aggregate.Min]) {
+ aggregate.Min(aggOutput(i - 1))
+ } else if
(agg.aggregateFunction.isInstanceOf[aggregate.Average]) {
+ aggregate.Average(aggOutput(i - 1))
+ } else if (agg.aggregateFunction.isInstanceOf[aggregate.Sum]) {
+ aggregate.Sum(aggOutput(i - 1))
+ } else {
+ agg.aggregateFunction
+ }
+ }
+ agg.transform {
+ case a: aggregate.AggregateFunction => aggFunction
+ }
Review comment:
This change updates the optimized logical plan so that the aggregates reference the pushed-down scan output.
Before the update:
```
Aggregate [max(id#18) AS max(id)#21, min(id#18) AS min(id)#22]
+- RelationV2[ID#18] test.people
```
After the update:
```
Aggregate [max(max(ID)#77) AS max(ID)#72, min(min(ID)#78) AS min(ID)#73]
+- RelationV2[max(ID)#77, min(ID)#78] test.people
```
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]