This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new aec79534abf [SPARK-39148][SQL] DS V2 aggregate push down can work with OFFSET or LIMIT
aec79534abf is described below
commit aec79534abf819e7981babc73d13450ea8e49b08
Author: Wenchen Fan <[email protected]>
AuthorDate: Wed Jul 20 11:13:08 2022 +0800
[SPARK-39148][SQL] DS V2 aggregate push down can work with OFFSET or LIMIT
### What changes were proposed in this pull request?
This PR refactors the v2 aggregate pushdown code. The main change is that we no
longer build the `Scan` immediately when pushing down the aggregate. We did so
before because we wanted to know the data schema after the aggregate was pushed,
so that we could add casts when rewriting the query plan after pushdown. The
problem is that building the `Scan` this early prevents pushing down any more
operators, while it's common to see LIMIT/OFFSET after an aggregate.
The idea of the refactor is that we don't need to know the data schema after the
aggregate is pushed. We just state an expectation (the data types should be the
same as those of the Spark aggregate functions), use it to define the output of
`ScanBuilderHolder`, and then rewrite the query plan. Later on, when we build the
`Scan` and replace `ScanBuilderHolder` with `DataSourceV2ScanRelation`, we check
the actual data schema and add a `Project` to cast the types if necessary.
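The late-build step is roughly the new `buildScanWithPushedAggregate` rule shown below (a simplified sketch of the code in this diff; `getWrappedScan` and `addCastIfNeeded` are existing private helpers of the rule, and the full version also asserts on the returned column count and castability):
```scala
// Simplified sketch: build the Scan only after limit/offset pushdown has run,
// then reconcile the holder's expected output with the data source's real schema.
def buildScanWithPushedAggregate(plan: LogicalPlan): LogicalPlan = plan.transform {
  case holder: ScanBuilderHolder if holder.pushedAggregate.isDefined =>
    val scan = holder.builder.build()
    val realOutput = scan.readSchema().toAttributes
    val scanRelation =
      DataSourceV2ScanRelation(holder.relation, getWrappedScan(scan, holder), realOutput)
    // The data source may return different data types than expected, so cast
    // back to the data types of the Spark aggregate functions.
    val projectList = realOutput.zip(holder.output).map { case (a1, a2) =>
      Alias(addCastIfNeeded(a1, a2.dataType), a2.name)(a2.exprId)
    }
    Project(projectList, scanRelation)
}
```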
### Why are the changes needed?
Support pushing down LIMIT/OFFSET after an aggregate.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
updated tests
Closes #37195 from cloud-fan/agg.
Lead-authored-by: Wenchen Fan <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../datasources/v2/V2ScanRelationPushDown.scala | 419 +++++++++++----------
.../org/apache/spark/sql/jdbc/JDBCV2Suite.scala | 38 +-
2 files changed, 254 insertions(+), 203 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 8951c37e127..f1e0e6d80c5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -26,12 +26,12 @@ import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LimitAndOffset, LocalLimit, LogicalPlan, Offset, OffsetAndLimit, Project, Sample, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.expressions.{SortOrder => V2SortOrder}
-import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, GeneralAggregateFunc, Sum, UserDefinedAggregateFunc}
+import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, CountStar, Max, Min, Sum}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.sources
-import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StructType}
+import org.apache.spark.sql.types.{DataType, DecimalType, IntegerType, StructType}
import org.apache.spark.sql.util.SchemaUtils._
object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper with AliasHelper {
@@ -44,6 +44,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
pushDownFilters,
pushDownAggregates,
pushDownLimitAndOffset,
+ buildScanWithPushedAggregate,
pruneColumns)
pushdownRules.foldLeft(plan) { (newPlan, pushDownRule) =>
@@ -92,189 +93,201 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
def pushDownAggregates(plan: LogicalPlan): LogicalPlan = plan.transform {
// update the scan builder with agg pushdown and return a new plan with agg pushed
- case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) =>
- child match {
- case ScanOperation(project, filters, sHolder: ScanBuilderHolder)
- if filters.isEmpty && CollapseProject.canCollapseExpressions(
- resultExpressions, project, alwaysInline = true) =>
- sHolder.builder match {
- case r: SupportsPushDownAggregates =>
- val aliasMap = getAliasMap(project)
- val actualResultExprs =
resultExpressions.map(replaceAliasButKeepName(_, aliasMap))
- val actualGroupExprs = groupingExpressions.map(replaceAlias(_,
aliasMap))
-
- val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression,
Int]
- val aggregates = collectAggregates(actualResultExprs,
aggExprToOutputOrdinal)
- val normalizedAggregates = DataSourceStrategy.normalizeExprs(
- aggregates,
sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]]
- val normalizedGroupingExpressions =
DataSourceStrategy.normalizeExprs(
- actualGroupExprs, sHolder.relation.output)
- val translatedAggregates =
DataSourceStrategy.translateAggregation(
- normalizedAggregates, normalizedGroupingExpressions)
- val (finalResultExpressions, finalAggregates,
finalTranslatedAggregates) = {
- if (translatedAggregates.isEmpty ||
- r.supportCompletePushDown(translatedAggregates.get) ||
-
translatedAggregates.get.aggregateExpressions().forall(!_.isInstanceOf[Avg])) {
- (actualResultExprs, aggregates, translatedAggregates)
- } else {
- // scalastyle:off
- // The data source doesn't support the complete push-down of
this aggregation.
- // Here we translate `AVG` to `SUM / COUNT`, so that it's
more likely to be
- // pushed, completely or partially.
- // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
- // SELECT avg(c1) FROM t GROUP BY c2;
- // The original logical plan is
- // Aggregate [c2#10],[avg(c1#9) AS avg(c1)#19]
- // +- ScanOperation[...]
- //
- // After convert avg(c1#9) to sum(c1#9)/count(c1#9)
- // we have the following
- // Aggregate [c2#10],[sum(c1#9)/count(c1#9) AS avg(c1)#19]
- // +- ScanOperation[...]
- // scalastyle:on
- val newResultExpressions = actualResultExprs.map { expr =>
- expr.transform {
- case AggregateExpression(avg: aggregate.Average, _,
isDistinct, _, _) =>
- val sum =
aggregate.Sum(avg.child).toAggregateExpression(isDistinct)
- val count =
aggregate.Count(avg.child).toAggregateExpression(isDistinct)
- avg.evaluateExpression transform {
- case a: Attribute if a.semanticEquals(avg.sum) =>
- addCastIfNeeded(sum, avg.sum.dataType)
- case a: Attribute if a.semanticEquals(avg.count) =>
- addCastIfNeeded(count, avg.count.dataType)
- }
- }
- }.asInstanceOf[Seq[NamedExpression]]
- // Because aggregate expressions changed, translate them
again.
- aggExprToOutputOrdinal.clear()
- val newAggregates =
- collectAggregates(newResultExpressions,
aggExprToOutputOrdinal)
- val newNormalizedAggregates =
DataSourceStrategy.normalizeExprs(
- newAggregates,
sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]]
- (newResultExpressions, newAggregates,
DataSourceStrategy.translateAggregation(
- newNormalizedAggregates, normalizedGroupingExpressions))
+ case agg: Aggregate => rewriteAggregate(agg)
+ }
+
+ private def rewriteAggregate(agg: Aggregate): LogicalPlan = agg.child match {
+ case ScanOperation(project, Nil, holder @ ScanBuilderHolder(_, _,
+ r: SupportsPushDownAggregates)) if
CollapseProject.canCollapseExpressions(
+ agg.aggregateExpressions, project, alwaysInline = true) =>
+ val aliasMap = getAliasMap(project)
+ val actualResultExprs =
agg.aggregateExpressions.map(replaceAliasButKeepName(_, aliasMap))
+ val actualGroupExprs = agg.groupingExpressions.map(replaceAlias(_,
aliasMap))
+
+ val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
+ val aggregates = collectAggregates(actualResultExprs,
aggExprToOutputOrdinal)
+ val normalizedAggExprs = DataSourceStrategy.normalizeExprs(
+ aggregates,
holder.relation.output).asInstanceOf[Seq[AggregateExpression]]
+ val normalizedGroupingExpr = DataSourceStrategy.normalizeExprs(
+ actualGroupExprs, holder.relation.output)
+ val translatedAggOpt = DataSourceStrategy.translateAggregation(
+ normalizedAggExprs, normalizedGroupingExpr)
+ if (translatedAggOpt.isEmpty) {
+ // Cannot translate the catalyst aggregate, return the query plan
unchanged.
+ return agg
+ }
+
+ val (finalResultExprs, finalAggExprs, translatedAgg,
canCompletePushDown) = {
+ if (r.supportCompletePushDown(translatedAggOpt.get)) {
+ (actualResultExprs, normalizedAggExprs, translatedAggOpt.get, true)
+ } else if
(!translatedAggOpt.get.aggregateExpressions().exists(_.isInstanceOf[Avg])) {
+ (actualResultExprs, normalizedAggExprs, translatedAggOpt.get, false)
+ } else {
+ // scalastyle:off
+ // The data source doesn't support the complete push-down of this
aggregation.
+ // Here we translate `AVG` to `SUM / COUNT`, so that it's more
likely to be
+ // pushed, completely or partially.
+ // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
+ // SELECT avg(c1) FROM t GROUP BY c2;
+ // The original logical plan is
+ // Aggregate [c2#10],[avg(c1#9) AS avg(c1)#19]
+ // +- ScanOperation[...]
+ //
+ // After convert avg(c1#9) to sum(c1#9)/count(c1#9)
+ // we have the following
+ // Aggregate [c2#10],[sum(c1#9)/count(c1#9) AS avg(c1)#19]
+ // +- ScanOperation[...]
+ // scalastyle:on
+ val newResultExpressions = actualResultExprs.map { expr =>
+ expr.transform {
+ case AggregateExpression(avg: aggregate.Average, _, isDistinct,
_, _) =>
+ val sum =
aggregate.Sum(avg.child).toAggregateExpression(isDistinct)
+ val count =
aggregate.Count(avg.child).toAggregateExpression(isDistinct)
+ avg.evaluateExpression transform {
+ case a: Attribute if a.semanticEquals(avg.sum) =>
+ addCastIfNeeded(sum, avg.sum.dataType)
+ case a: Attribute if a.semanticEquals(avg.count) =>
+ addCastIfNeeded(count, avg.count.dataType)
}
- }
+ }
+ }.asInstanceOf[Seq[NamedExpression]]
+ // Because aggregate expressions changed, translate them again.
+ aggExprToOutputOrdinal.clear()
+ val newAggregates =
+ collectAggregates(newResultExpressions, aggExprToOutputOrdinal)
+ val newNormalizedAggExprs = DataSourceStrategy.normalizeExprs(
+ newAggregates,
holder.relation.output).asInstanceOf[Seq[AggregateExpression]]
+ val newTranslatedAggOpt = DataSourceStrategy.translateAggregation(
+ newNormalizedAggExprs, normalizedGroupingExpr)
+ if (newTranslatedAggOpt.isEmpty) {
+ // Ideally we should never reach here. But if we end up with not
able to translate
+ // new aggregate with AVG replaced by SUM/COUNT, revert to the
original one.
+ (actualResultExprs, normalizedAggExprs, translatedAggOpt.get,
false)
+ } else {
+ (newResultExpressions, newNormalizedAggExprs,
newTranslatedAggOpt.get,
+ r.supportCompletePushDown(newTranslatedAggOpt.get))
+ }
+ }
+ }
- if (finalTranslatedAggregates.isEmpty) {
- aggNode // return original plan node
- } else if
(!r.supportCompletePushDown(finalTranslatedAggregates.get) &&
- !supportPartialAggPushDown(finalTranslatedAggregates.get)) {
- aggNode // return original plan node
- } else {
- val pushedAggregates =
finalTranslatedAggregates.filter(r.pushAggregation)
- if (pushedAggregates.isEmpty) {
- aggNode // return original plan node
- } else {
- // No need to do column pruning because only the aggregate
columns are used as
- // DataSourceV2ScanRelation output columns. All the other
columns are not
- // included in the output.
- val scan = sHolder.builder.build()
-
- // scalastyle:off
- // use the group by columns and aggregate columns as the
output columns
- // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
- // SELECT min(c1), max(c1) FROM t GROUP BY c2;
- // Use c2, min(c1), max(c1) as output for
DataSourceV2ScanRelation
- // We want to have the following logical plan:
- // == Optimized Logical Plan ==
- // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17,
max(max(c1)#22) AS max(c1)#18]
- // +- RelationV2[c2#10, min(c1)#21, max(c1)#22]
- // scalastyle:on
- val newOutput = scan.readSchema().toAttributes
- assert(newOutput.length == groupingExpressions.length +
finalAggregates.length)
- val groupByExprToOutputOrdinal =
mutable.HashMap.empty[Expression, Int]
- val groupAttrs =
normalizedGroupingExpressions.zip(newOutput).zipWithIndex.map {
- case ((a: Attribute, b: Attribute), _) =>
b.withExprId(a.exprId)
- case ((expr, attr), ordinal) =>
- if
(!groupByExprToOutputOrdinal.contains(expr.canonicalized)) {
- groupByExprToOutputOrdinal(expr.canonicalized) =
ordinal
- }
- attr
- }
- val aggOutput = newOutput.drop(groupAttrs.length)
- val output = groupAttrs ++ aggOutput
-
- logInfo(
- s"""
- |Pushing operators to ${sHolder.relation.name}
- |Pushed Aggregate Functions:
- |
${pushedAggregates.get.aggregateExpressions.mkString(", ")}
- |Pushed Group by:
- | ${pushedAggregates.get.groupByExpressions.mkString(",
")}
- |Output: ${output.mkString(", ")}
- """.stripMargin)
-
- val wrappedScan = getWrappedScan(scan, sHolder,
pushedAggregates)
- val scanRelation =
- DataSourceV2ScanRelation(sHolder.relation, wrappedScan,
output)
- if (r.supportCompletePushDown(pushedAggregates.get)) {
- val projectExpressions = finalResultExpressions.map { expr
=>
- expr.transformDown {
- case agg: AggregateExpression =>
- val ordinal =
aggExprToOutputOrdinal(agg.canonicalized)
- val child =
- addCastIfNeeded(aggOutput(ordinal),
agg.resultAttribute.dataType)
- Alias(child,
agg.resultAttribute.name)(agg.resultAttribute.exprId)
- case expr if
groupByExprToOutputOrdinal.contains(expr.canonicalized) =>
- val ordinal =
groupByExprToOutputOrdinal(expr.canonicalized)
- addCastIfNeeded(groupAttrs(ordinal), expr.dataType)
- }
- }.asInstanceOf[Seq[NamedExpression]]
- Project(projectExpressions, scanRelation)
+ if (!canCompletePushDown && !supportPartialAggPushDown(translatedAgg)) {
+ return agg
+ }
+ if (!r.pushAggregation(translatedAgg)) {
+ return agg
+ }
+
+ // scalastyle:off
+ // We name the output columns of group expressions and aggregate
functions by
+ // ordinal: `group_col_0`, `group_col_1`, ..., `agg_func_0`,
`agg_func_1`, ...
+ // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
+ // SELECT min(c1), max(c1) FROM t GROUP BY c2;
+ // Use group_col_0, agg_func_0, agg_func_1 as output for
ScanBuilderHolder.
+ // We want to have the following logical plan:
+ // == Optimized Logical Plan ==
+ // Aggregate [group_col_0#10], [min(agg_func_0#21) AS min(c1)#17,
max(agg_func_1#22) AS max(c1)#18]
+ // +- ScanBuilderHolder[group_col_0#10, agg_func_0#21, agg_func_1#22]
+ // Later, we build the `Scan` instance and convert ScanBuilderHolder to
DataSourceV2ScanRelation.
+ // scalastyle:on
+ val groupOutput = normalizedGroupingExpr.zipWithIndex.map { case (e, i)
=>
+ AttributeReference(s"group_col_$i", e.dataType)()
+ }
+ val aggOutput = finalAggExprs.zipWithIndex.map { case (e, i) =>
+ AttributeReference(s"agg_func_$i", e.dataType)()
+ }
+ val newOutput = groupOutput ++ aggOutput
+ val groupByExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
+ normalizedGroupingExpr.zipWithIndex.foreach { case (expr, ordinal) =>
+ if (!groupByExprToOutputOrdinal.contains(expr.canonicalized)) {
+ groupByExprToOutputOrdinal(expr.canonicalized) = ordinal
+ }
+ }
+
+ holder.pushedAggregate = Some(translatedAgg)
+ holder.output = newOutput
+ logInfo(
+ s"""
+ |Pushing operators to ${holder.relation.name}
+ |Pushed Aggregate Functions:
+ | ${translatedAgg.aggregateExpressions().mkString(", ")}
+ |Pushed Group by:
+ | ${translatedAgg.groupByExpressions.mkString(", ")}
+ """.stripMargin)
+
+ if (canCompletePushDown) {
+ val projectExpressions = finalResultExprs.map { expr =>
+ expr.transformDown {
+ case agg: AggregateExpression =>
+ val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
+ Alias(aggOutput(ordinal),
agg.resultAttribute.name)(agg.resultAttribute.exprId)
+ case expr if
groupByExprToOutputOrdinal.contains(expr.canonicalized) =>
+ val ordinal = groupByExprToOutputOrdinal(expr.canonicalized)
+ expr match {
+ case ne: NamedExpression => Alias(groupOutput(ordinal),
ne.name)(ne.exprId)
+ case _ => groupOutput(ordinal)
+ }
+ }
+ }.asInstanceOf[Seq[NamedExpression]]
+ Project(projectExpressions, holder)
+ } else {
+ // scalastyle:off
+ // Change the optimized logical plan to reflect the pushed down
aggregate
+ // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
+ // SELECT min(c1), max(c1) FROM t GROUP BY c2;
+ // The original logical plan is
+ // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
+ // +- RelationV2[c1#9, c2#10] ...
+ //
+ // After change the V2ScanRelation output to [c2#10, min(c1)#21,
max(c1)#22]
+ // we have the following
+ // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS
max(c1)#18]
+ // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
+ //
+ // We want to change it to
+ // == Optimized Logical Plan ==
+ // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22)
AS max(c1)#18]
+ // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
+ // scalastyle:on
+ val aggExprs = finalResultExprs.map(_.transform {
+ case agg: AggregateExpression =>
+ val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
+ val aggAttribute = aggOutput(ordinal)
+ val aggFunction: aggregate.AggregateFunction =
+ agg.aggregateFunction match {
+ case max: aggregate.Max =>
+ max.copy(child = aggAttribute)
+ case min: aggregate.Min =>
+ min.copy(child = aggAttribute)
+ case sum: aggregate.Sum =>
+ // To keep the dataType of `Sum` unchanged, we need to cast
the
+ // data-source-aggregated result to `Sum.child.dataType` if
it's decimal.
+ // See `SumBase.resultType`
+ val newChild = if (sum.dataType.isInstanceOf[DecimalType]) {
+ addCastIfNeeded(aggAttribute, sum.child.dataType)
} else {
- val plan =
Aggregate(output.take(groupingExpressions.length),
- finalResultExpressions, scanRelation)
-
- // scalastyle:off
- // Change the optimized logical plan to reflect the pushed
down aggregate
- // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
- // SELECT min(c1), max(c1) FROM t GROUP BY c2;
- // The original logical plan is
- // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9)
AS max(c1)#18]
- // +- RelationV2[c1#9, c2#10] ...
- //
- // After change the V2ScanRelation output to [c2#10,
min(c1)#21, max(c1)#22]
- // we have the following
- // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9)
AS max(c1)#18]
- // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
- //
- // We want to change it to
- // == Optimized Logical Plan ==
- // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17,
max(max(c1)#22) AS max(c1)#18]
- // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
- // scalastyle:on
- plan.transformExpressions {
- case agg: AggregateExpression =>
- val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
- val aggAttribute = aggOutput(ordinal)
- val aggFunction: aggregate.AggregateFunction =
- agg.aggregateFunction match {
- case max: aggregate.Max =>
- max.copy(child = addCastIfNeeded(aggAttribute,
max.child.dataType))
- case min: aggregate.Min =>
- min.copy(child = addCastIfNeeded(aggAttribute,
min.child.dataType))
- case sum: aggregate.Sum =>
- sum.copy(child = addCastIfNeeded(aggAttribute,
sum.child.dataType))
- case _: aggregate.Count =>
- aggregate.Sum(addCastIfNeeded(aggAttribute,
LongType))
- case other => other
- }
- agg.copy(aggregateFunction = aggFunction)
- case expr if
groupByExprToOutputOrdinal.contains(expr.canonicalized) =>
- val ordinal =
groupByExprToOutputOrdinal(expr.canonicalized)
- addCastIfNeeded(groupAttrs(ordinal), expr.dataType)
- }
+ aggAttribute
}
- }
+ sum.copy(child = newChild)
+ case _: aggregate.Count =>
+ aggregate.Sum(aggAttribute)
+ case other => other
}
- case _ => aggNode
- }
- case _ => aggNode
+ agg.copy(aggregateFunction = aggFunction)
+ case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized)
=>
+ val ordinal = groupByExprToOutputOrdinal(expr.canonicalized)
+ expr match {
+ case ne: NamedExpression => Alias(groupOutput(ordinal),
ne.name)(ne.exprId)
+ case _ => groupOutput(ordinal)
+ }
+ }).asInstanceOf[Seq[NamedExpression]]
+ Aggregate(groupOutput, aggExprs, holder)
}
+
+ case _ => agg
}
- private def collectAggregates(resultExpressions: Seq[NamedExpression],
+ private def collectAggregates(
+ resultExpressions: Seq[NamedExpression],
aggExprToOutputOrdinal: mutable.HashMap[Expression, Int]): Seq[AggregateExpression] = {
var ordinal = 0
resultExpressions.flatMap { expr =>
@@ -292,15 +305,12 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
}
private def supportPartialAggPushDown(agg: Aggregation): Boolean = {
- // We don't know the agg buffer of `GeneralAggregateFunc`, so can't do partial agg push down.
- // If `Sum`, `Count`, `Avg` with distinct, can't do partial agg push down.
- agg.aggregateExpressions().isEmpty || agg.aggregateExpressions().exists {
+ // We can only partially push down min/max/sum/count without DISTINCT.
+ agg.aggregateExpressions().isEmpty || agg.aggregateExpressions().forall {
case sum: Sum => !sum.isDistinct
case count: Count => !count.isDistinct
- case avg: Avg => !avg.isDistinct
- case _: GeneralAggregateFunc => false
- case _: UserDefinedAggregateFunc => false
- case _ => true
+ case _: Min | _: Max | _: CountStar => true
+ case _ => false
}
}
@@ -311,6 +321,26 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
Cast(expression, expectedDataType)
}
+ def buildScanWithPushedAggregate(plan: LogicalPlan): LogicalPlan = plan.transform {
+ case holder: ScanBuilderHolder if holder.pushedAggregate.isDefined =>
+ // No need to do column pruning because only the aggregate columns are used as
+ // DataSourceV2ScanRelation output columns. All the other columns are not
+ // included in the output.
+ val scan = holder.builder.build()
+ val realOutput = scan.readSchema().toAttributes
+ assert(realOutput.length == holder.output.length,
+ "The data source returns unexpected number of columns")
+ val wrappedScan = getWrappedScan(scan, holder)
+ val scanRelation = DataSourceV2ScanRelation(holder.relation, wrappedScan, realOutput)
+ val projectList = realOutput.zip(holder.output).map { case (a1, a2) =>
+ // The data source may return columns with arbitrary data types and it's safer to cast them
+ // to the expected data type.
+ assert(Cast.canCast(a1.dataType, a2.dataType))
+ Alias(addCastIfNeeded(a1, a2.dataType), a2.name)(a2.exprId)
+ }
+ Project(projectList, scanRelation)
+ }
+
def pruneColumns(plan: LogicalPlan): LogicalPlan = plan.transform {
case ScanOperation(project, filters, sHolder: ScanBuilderHolder) =>
// column pruning
@@ -325,7 +355,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
|Output: ${output.mkString(", ")}
""".stripMargin)
- val wrappedScan = getWrappedScan(scan, sHolder, Option.empty[Aggregation])
+ val wrappedScan = getWrappedScan(scan, sHolder)
val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
@@ -378,8 +408,12 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
}
(operation, isPushed && !isPartiallyPushed)
case s @ Sort(order, _, operation @ ScanOperation(project, filter, sHolder: ScanBuilderHolder))
- if filter.isEmpty && CollapseProject.canCollapseExpressions(
- order, project, alwaysInline = true) =>
+ // Without building the Scan, we do not know the resulting column names after aggregate
+ // push-down, and thus can't push down Top-N which needs to know the ordering column names.
+ // TODO: we can support simple cases like GROUP BY columns directly and ORDER BY the same
+ // columns, which we know the resulting column names: the original table columns.
+ if sHolder.pushedAggregate.isEmpty && filter.isEmpty &&
+ CollapseProject.canCollapseExpressions(order, project, alwaysInline = true) =>
val aliasMap = getAliasMap(project)
val newOrder = order.map(replaceAlias(_, aliasMap)).asInstanceOf[Seq[SortOrder]]
val normalizedOrders = DataSourceStrategy.normalizeExprs(
@@ -480,10 +514,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
}
}
- private def getWrappedScan(
- scan: Scan,
- sHolder: ScanBuilderHolder,
- aggregation: Option[Aggregation]): Scan = {
+ private def getWrappedScan(scan: Scan, sHolder: ScanBuilderHolder): Scan = {
scan match {
case v1: V1Scan =>
val pushedFilters = sHolder.builder match {
@@ -491,7 +522,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
f.pushedFilters()
case _ => Array.empty[sources.Filter]
}
- val pushedDownOperators = PushedDownOperators(aggregation, sHolder.pushedSample,
+ val pushedDownOperators = PushedDownOperators(sHolder.pushedAggregate, sHolder.pushedSample,
sHolder.pushedLimit, sHolder.pushedOffset, sHolder.sortOrders, sHolder.pushedPredicates)
V1ScanWrapper(v1, pushedFilters, pushedDownOperators)
case _ => scan
@@ -500,7 +531,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
}
case class ScanBuilderHolder(
- output: Seq[AttributeReference],
+ var output: Seq[AttributeReference],
relation: DataSourceV2Relation,
builder: ScanBuilder) extends LeafNode {
var pushedLimit: Option[Int] = None
@@ -512,6 +543,8 @@ case class ScanBuilderHolder(
var pushedSample: Option[TableSampleInfo] = None
var pushedPredicates: Seq[Predicate] = Seq.empty[Predicate]
+
+ var pushedAggregate: Option[Aggregation] = None
}
// A wrapper for v1 scan to carry the translated filters and the handled ones, along with
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 7e772c0febb..d64b1815007 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -265,9 +265,13 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
.table("h2.test.employee")
.groupBy("DEPT").sum("SALARY")
.limit(1)
- checkLimitRemoved(df4, false)
+ checkAggregateRemoved(df4)
+ checkLimitRemoved(df4)
checkPushedInfo(df4,
- "PushedAggregates: [SUM(SALARY)], PushedFilters: [],
PushedGroupByExpressions: [DEPT], ")
+ "PushedAggregates: [SUM(SALARY)]",
+ "PushedGroupByExpressions: [DEPT]",
+ "PushedFilters: []",
+ "PushedLimit: LIMIT 1")
checkAnswer(df4, Seq(Row(1, 19000.00)))
val name = udf { (x: String) => x.matches("cat|dav|amy") }
@@ -340,9 +344,13 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
.table("h2.test.employee")
.groupBy("DEPT").sum("SALARY")
.offset(1)
- checkOffsetRemoved(df5, false)
+ checkAggregateRemoved(df5)
+ checkLimitRemoved(df5)
checkPushedInfo(df5,
- "PushedAggregates: [SUM(SALARY)], PushedFilters: [],
PushedGroupByExpressions: [DEPT], ")
+ "PushedAggregates: [SUM(SALARY)]",
+ "PushedGroupByExpressions: [DEPT]",
+ "PushedFilters: []",
+ "PushedOffset: OFFSET 1")
checkAnswer(df5, Seq(Row(2, 22000.00), Row(6, 12000.00)))
val name = udf { (x: String) => x.matches("cat|dav|amy") }
@@ -477,10 +485,15 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
.groupBy("DEPT").sum("SALARY")
.limit(2)
.offset(1)
- checkLimitRemoved(df10, false)
- checkOffsetRemoved(df10, false)
+ checkAggregateRemoved(df10)
+ checkLimitRemoved(df10)
+ checkOffsetRemoved(df10)
checkPushedInfo(df10,
- "PushedAggregates: [SUM(SALARY)], PushedFilters: [],
PushedGroupByExpressions: [DEPT], ")
+ "PushedAggregates: [SUM(SALARY)]",
+ "PushedGroupByExpressions: [DEPT]",
+ "PushedFilters: []",
+ "PushedLimit: LIMIT 2",
+ "PushedOffset: OFFSET 1")
checkAnswer(df10, Seq(Row(2, 22000.00)))
val name = udf { (x: String) => x.matches("cat|dav|amy") }
@@ -612,10 +625,15 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkAnswer(df9, Seq(Row(2, "david", 10000.00, 1300.0, true)))
val df10 = sql("SELECT dept, sum(salary) FROM h2.test.employee group by
dept LIMIT 1 OFFSET 1")
- checkLimitRemoved(df10, false)
- checkOffsetRemoved(df10, false)
+ checkAggregateRemoved(df10)
+ checkLimitRemoved(df10)
+ checkOffsetRemoved(df10)
checkPushedInfo(df10,
- "PushedAggregates: [SUM(SALARY)], PushedFilters: [],
PushedGroupByExpressions: [DEPT], ")
+ "PushedAggregates: [SUM(SALARY)]",
+ "PushedGroupByExpressions: [DEPT]",
+ "PushedFilters: []",
+ "PushedLimit: LIMIT 2",
+ "PushedOffset: OFFSET 1")
checkAnswer(df10, Seq(Row(2, 22000.00)))
val name = udf { (x: String) => x.matches("cat|dav|amy") }