maropu commented on a change in pull request #27058: [SPARK-30276][SQL] Support
Filter expression allows simultaneous use of DISTINCT
URL: https://github.com/apache/spark/pull/27058#discussion_r368808060
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
##########
@@ -148,24 +207,106 @@ object RewriteDistinctAggregates extends
Rule[LogicalPlan] {
val distinctAggs = exprs.flatMap { _.collect {
case ae: AggregateExpression if ae.isDistinct => ae
}}
- // We need at least two distinct aggregates for this rule because
aggregation
- // strategy can handle a single distinct group.
+ // This rule serves two purposes:
+ // One is to rewrite when there exist at least two distinct aggregates. We need at least
+ // two distinct aggregates for this rule because the aggregation strategy can handle a single
+ // distinct group.
+ // The other is to expand distinct aggregates that have a filter clause so that we can
+ // evaluate the filter locally.
// This check can produce false-positives, e.g., SUM(DISTINCT a) &
COUNT(DISTINCT a).
- distinctAggs.size > 1
+ distinctAggs.size >= 1 || distinctAggs.exists(_.filter.isDefined)
}
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
- case a: Aggregate if mayNeedtoRewrite(a.aggregateExpressions) => rewrite(a)
+ case a: Aggregate if mayNeedtoRewrite(a.aggregateExpressions) =>
+ val expandAggregate = extractFiltersInDistinctAggregate(a)
+ rewriteDistinctAggregate(expandAggregate)
}
- def rewrite(a: Aggregate): Aggregate = {
+ private def extractFiltersInDistinctAggregate(a: Aggregate): Aggregate = {
+ val aggExpressions = collectAggregateExprs(a)
+ val (distinctAggExpressions, regularAggExpressions) =
aggExpressions.partition(_.isDistinct)
+ if (distinctAggExpressions.exists(_.filter.isDefined)) {
+ // Set up the expand for the 'regular' aggregate expressions. Because we will construct a new
+ // aggregate, the children of the distinct aggregates will be changed to the generated
+ // ones, so we need to create new references to avoid collisions between distinct and
+ // regular aggregate children.
+ val regularAggExprs =
regularAggExpressions.filter(_.children.exists(!_.foldable))
+ val regularFunChildren = regularAggExprs
+ .flatMap(_.aggregateFunction.children.filter(!_.foldable))
+ val regularFilterAttrs = regularAggExprs.flatMap(_.filterAttributes)
+ val regularAggChildren = (regularFunChildren ++
regularFilterAttrs).distinct
+ val regularAggChildAttrMap =
regularAggChildren.map(expressionAttributePair)
+ val regularAggChildAttrLookup = regularAggChildAttrMap.toMap
+ val regularAggMap = regularAggExprs.map {
+ case ae @ AggregateExpression(af, _, _, filter, _) =>
+ val newChildren = af.children.map(c =>
regularAggChildAttrLookup.getOrElse(c, c))
+ val raf =
af.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
+ val filterOpt = filter.map(_.transform {
+ case a: Attribute => regularAggChildAttrLookup.getOrElse(a, a)
+ })
+ val aggExpr = ae.copy(aggregateFunction = raf, filter = filterOpt)
+ (ae, aggExpr)
+ }
- // Collect all aggregate expressions.
- val aggExpressions = a.aggregateExpressions.flatMap { e =>
- e.collect {
- case ae: AggregateExpression => ae
+ // Set up the expand for the distinct aggregate expressions.
+ val distinctAggExprs = distinctAggExpressions.filter(e =>
e.children.exists(!_.foldable))
+ val (projections, expressionAttrs, aggExprPairs) = distinctAggExprs.map {
+ case ae @ AggregateExpression(af, _, _, filter, _) =>
+ // Why do we need to construct the `exprId`?
+ // First, to reduce cost, it is better to evaluate the filter clause locally.
+ // e.g. for COUNT(DISTINCT a) FILTER (WHERE id > 1), first evaluate the expression
+ // If(id > 1) 'a else null, and use the result as the output.
+ // Second, two or more DISTINCT aggregate expressions may reference the same
+ // attributes, so we need to construct generated attributes to avoid losing output.
+ // e.g. SUM(DISTINCT a), COUNT(DISTINCT a) FILTER (WHERE id > 1) will output
+ // attributes '_gen_distinct_1 and '_gen_distinct_2 instead of two 'a.
+ // Note: only the children of expressions that have a filter clause are wrapped in `If`.
+ // This rewriting may result in multiple distinct aggregations that use
+ // different columns, so we still need to call `rewriteDistinctAggregate`.
+ val exprId = NamedExpression.newExprId.id
+ val unfoldableChildren = af.children.filter(!_.foldable)
+ val exprAttrs = unfoldableChildren.map { e =>
+ (e, AttributeReference(s"_gen_distinct_$exprId", e.dataType,
nullable = true)())
+ }
+ val exprAttrLookup = exprAttrs.toMap
+ val newChildren = af.children.map(c => exprAttrLookup.getOrElse(c,
c))
+ val raf =
af.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
+ val aggExpr = ae.copy(aggregateFunction = raf, filter = None)
+ // Expand projection
+ val projection = unfoldableChildren.map {
+ case e if filter.isDefined => If(filter.get, e, nullify(e))
+ case e => e
+ }
+ (projection, exprAttrs, (ae, aggExpr))
+ }.unzip3
+ val distinctAggChildAttrs = expressionAttrs.flatten.map(_._2)
+ val allAggAttrs = regularAggChildAttrMap.map(_._2) ++
distinctAggChildAttrs
+ // Construct the aggregate input projection.
+ val rewriteAggProjections =
+ Seq(a.groupingExpressions ++ regularAggChildren ++ projections.flatten)
+ val groupByMap = a.groupingExpressions.collect {
+ case ne: NamedExpression => ne -> ne.toAttribute
+ case e => e -> AttributeReference(e.sql, e.dataType, e.nullable)()
}
+ val groupByAttrs = groupByMap.map(_._2)
+ // Construct the expand operator.
+ val expand = Expand(rewriteAggProjections, groupByAttrs ++ allAggAttrs,
a.child)
+ val rewriteAggExprLookup = (aggExprPairs ++ regularAggMap).toMap
+ val patchedAggExpressions = a.aggregateExpressions.map { e =>
+ e.transformDown {
+ case ae: AggregateExpression => rewriteAggExprLookup.getOrElse(ae,
ae)
+ }.asInstanceOf[NamedExpression]
+ }
+ val expandAggregate = Aggregate(groupByAttrs, patchedAggExpressions,
expand)
+ expandAggregate
+ } else {
+ a
}
+ }
+
+ private def rewriteDistinctAggregate(a: Aggregate): Aggregate = {
Review comment:
nit: `rewriteDistinctAggregate`->`rewriteDistinctAggregates`
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]