cloud-fan commented on a change in pull request #33352:
URL: https://github.com/apache/spark/pull/33352#discussion_r670690954



##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
##########
@@ -17,61 +17,193 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
-import org.apache.spark.sql.catalyst.expressions.{And, Expression, 
NamedExpression, ProjectionOverSchema, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, 
AttributeReference, Expression, NamedExpression, PredicateHelper, 
ProjectionOverSchema, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.aggregate
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.planning.ScanOperation
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, 
Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, 
LeafNode, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.connector.read.{Scan, V1Scan}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, 
SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.sources
 import org.apache.spark.sql.types.StructType
 
-object V2ScanRelationPushDown extends Rule[LogicalPlan] {
+object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
   import DataSourceV2Implicits._
 
-  override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
-    case ScanOperation(project, filters, relation: DataSourceV2Relation) =>
-      val scanBuilder = 
relation.table.asReadable.newScanBuilder(relation.options)
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    
applyColumnPruning(pushdownAggregate(pushDownFilters(createScanBuilder(plan))))
+  }
+
+  private def createScanBuilder(plan: LogicalPlan) = plan.transform {
+    case r: DataSourceV2Relation =>
+      ScanBuilderHolder(r.output, r, 
r.table.asReadable.newScanBuilder(r.options))
+  }
 
-      val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, 
relation.output)
+  private def pushDownFilters(plan: LogicalPlan) = plan.transform {
+    // update the scan builder with filter push down and return a new plan 
with filter pushed
+    case Filter(condition, sHolder: ScanBuilderHolder) =>
+      val filters = splitConjunctivePredicates(condition)
+      val normalizedFilters =
+        DataSourceStrategy.normalizeExprs(filters, sHolder.relation.output)
       val (normalizedFiltersWithSubquery, normalizedFiltersWithoutSubquery) =
         normalizedFilters.partition(SubqueryExpression.hasSubquery)
 
       // `pushedFilters` will be pushed down and evaluated in the underlying 
data sources.
       // `postScanFilters` need to be evaluated after the scan.
       // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet 
row group filter.
       val (pushedFilters, postScanFiltersWithoutSubquery) = 
PushDownUtils.pushFilters(
-        scanBuilder, normalizedFiltersWithoutSubquery)
+        sHolder.builder, normalizedFiltersWithoutSubquery)
       val postScanFilters = postScanFiltersWithoutSubquery ++ 
normalizedFiltersWithSubquery
 
+      logInfo(
+        s"""
+           |Pushing operators to ${sHolder.relation.name}
+           |Pushed Filters: ${pushedFilters.mkString(", ")}
+           |Post-Scan Filters: ${postScanFilters.mkString(",")}
+         """.stripMargin)
+
+      val filterCondition = postScanFilters.reduceLeftOption(And)
+      filterCondition.map(Filter(_, sHolder)).getOrElse(sHolder)
+  }
+
+  def pushdownAggregate(plan: LogicalPlan): LogicalPlan = plan.transform {
+    // update the scan builder with agg pushdown and return a new plan with 
agg pushed
+    case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) =>
+      child match {
+        case ScanOperation(project, filters, sHolder: ScanBuilderHolder)
+          if project.forall(_.isInstanceOf[AttributeReference]) =>
+          sHolder.builder match {
+            case _: SupportsPushDownAggregates =>
+              if (filters.length == 0) { // can't push down aggregate if 
postScanFilters exist
+                val aggregates = resultExpressions.flatMap { expr =>
+                  expr.collect {
+                    case agg: AggregateExpression => agg
+                  }
+                }
+                val pushedAggregates = PushDownUtils
+                  .pushAggregates(sHolder.builder, aggregates, 
groupingExpressions)
+                if (pushedAggregates.isEmpty) {
+                  aggNode // return original plan node
+                } else {
+                  // No need to do column pruning because only the aggregate 
columns are used as
+                  // DataSourceV2ScanRelation output columns. All the other 
columns are not
+                  // included in the output. Since PushDownUtils.pruneColumns 
is not called,
+                  // ScanBuilder.requiredSchema is not pruned, but 
ScanBuilder.requiredSchema is
+                  // not used anyways. The schema for aggregate columns will 
be built in Scan.
+                  val scan = sHolder.builder.build()
+
+                  // scalastyle:off
+                  // use the group by columns and aggregate columns as the 
output columns
+                  // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
+                  // SELECT min(c1), max(c1) FROM t GROUP BY c2;
+                  // Use c2, min(c1), max(c1) as output for 
DataSourceV2ScanRelation
+                  // We want to have the following logical plan:
+                  // == Optimized Logical Plan ==
+                  // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, 
max(max(c1)#22) AS max(c1)#18]
+                  // +- RelationV2[c2#10, min(c1)#21, max(c1)#22]
+                  // scalastyle:on
+                  val newOutput = scan.readSchema().toAttributes
+                  val groupAttrs = groupingExpressions.zip(newOutput).map {
+                    case (a: Attribute, b: Attribute) => b.withExprId(a.exprId)
+                    case (_, b) => b
+                  }
+                  val output = groupAttrs ++ newOutput.drop(groupAttrs.length)
+
+                  logInfo(
+                    s"""
+                       |Pushing operators to ${sHolder.relation.name}
+                       |Pushed Aggregate Functions:
+                       | 
${pushedAggregates.get.aggregateExpressions.mkString(", ")}
+                       |Pushed Group by:
+                       | ${pushedAggregates.get.groupByColumns.mkString(", ")}
+                       |Output: ${output.mkString(", ")}
+                      """.stripMargin)
+
+                  val scanRelation = 
DataSourceV2ScanRelation(sHolder.relation, scan, output)
+                  assert(scanRelation.output.length ==
+                    groupingExpressions.length + aggregates.length)
+
+                  val plan = Aggregate(
+                    output.take(groupingExpressions.length), 
resultExpressions, scanRelation)
+
+                  // scalastyle:off
+                  // Change the optimized logical plan to reflect the pushed 
down aggregate
+                  // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
+                  // SELECT min(c1), max(c1) FROM t GROUP BY c2;
+                  // The original logical plan is
+                  // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS 
max(c1)#18]
+                  // +- RelationV2[c1#9, c2#10] ...
+                  //
+                  // After change the V2ScanRelation output to [c2#10, 
min(c1)#21, max(c1)#22]
+                  // we have the following
+                  // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) 
AS max(c1)#18]
+                  // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
+                  //
+                  // We want to change it to
+                  // == Optimized Logical Plan ==
+                  // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, 
max(max(c1)#22) AS max(c1)#18]
+                  // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
+                  // scalastyle:on
+                  var i = 0
+                  val aggOutput = output.drop(groupAttrs.length)
+                  plan.transformExpressions {
+                    case agg: AggregateExpression =>
+                      i += 1
+                      val aggFunction: aggregate.AggregateFunction =
+                        agg.aggregateFunction match {
+                          case _: aggregate.Max => aggregate.Max(aggOutput(i - 
1))
+                          case _: aggregate.Min => aggregate.Min(aggOutput(i - 
1))
+                          case _: aggregate.Sum => aggregate.Sum(aggOutput(i - 
1))
+                          case _: aggregate.Count => aggregate.Sum(aggOutput(i 
- 1))
+                          case _ => agg.aggregateFunction
+                        }

Review comment:
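      We can move `i += 1` here, after the `match`. Then `i` is still the zero-based index of the current aggregate while the cases above run, so each of them can use `aggOutput(i)` instead of `aggOutput(i - 1)`.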




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to