Github user marmbrus commented on a diff in the pull request:
https://github.com/apache/spark/pull/9406#discussion_r44199453
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
---
@@ -213,3 +216,178 @@ object Utils {
case other => None
}
}
+
+/**
+ * This rule rewrites an aggregate query with multiple distinct clauses
into an expanded double
+ * aggregation in which the regular aggregation expressions and every
distinct clause is aggregated
+ * in a separate group. The results are then combined in a second
aggregate.
+ *
+ * TODO Expression canonicalization
+ * TODO Eliminate foldable expressions from distinct clauses.
+ * TODO This eliminates all distinct expressions. We could safely pass one
to the aggregate
+ * operator. Perhaps this is a good thing? It is much simpler to plan
later on...
+ */
+object MultipleDistinctRewriter extends Rule[LogicalPlan] {
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case a: Aggregate => rewrite(a)
+ case p => p
+ }
+
+ def rewrite(a: Aggregate): Aggregate = {
+
+ // Collect all aggregate expressions.
+ val aggExpressions = a.aggregateExpressions.flatMap { e =>
+ e.collect {
+ case ae: AggregateExpression2 => ae
+ }
+ }
+
+ // Extract distinct aggregate expressions.
+ val distinctAggGroups = aggExpressions
+ .filter(_.isDistinct)
+ .groupBy(_.aggregateFunction.children.toSet)
+
+ // Only continue to rewrite if there is more than one distinct group.
+ if (distinctAggGroups.size > 1) {
+ // Create the attributes for the grouping id and the group by clause.
+ val gid = new AttributeReference("gid", IntegerType, false)()
+ val groupByMap = a.groupingExpressions.collect {
+ case ne: NamedExpression => ne -> ne.toAttribute
+ case e => e -> new AttributeReference(e.prettyName, e.dataType,
e.nullable)()
+ }
+ val groupByAttrs = groupByMap.map(_._2)
+
+ // Functions used to modify aggregate functions and their inputs.
+ def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid,
id), e, nullify(e))
+ def patchAggregateFunctionChildren(
+ af: AggregateFunction2,
+ id: Literal,
+ attrs: Map[Expression, Expression]): AggregateFunction2 = {
+ af.withNewChildren(af.children.map { case afc =>
+ evalWithinGroup(id, attrs(afc))
+ }).asInstanceOf[AggregateFunction2]
+ }
+
+ // Setup unique distinct aggregate children.
+ val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq
+ val distinctAggChildAttrMap =
distinctAggChildren.map(expressionAttributePair).toMap
+ val distinctAggChildAttrs = distinctAggChildAttrMap.values.toSeq
+
+ // Setup expand & aggregate operators for distinct aggregate
expressions.
+ val distinctAggOperatorMap =
distinctAggGroups.toSeq.zipWithIndex.map {
+ case ((group, expressions), i) =>
+ val id = Literal(i + 1)
+
+ // Expand projection
+ val projection = distinctAggChildren.map {
+ case e if group.contains(e) => e
+ case e => nullify(e)
+ } :+ id
+
+ // Final aggregate
+ val operators = expressions.map { e =>
+ val af = e.aggregateFunction
+ val naf = patchAggregateFunctionChildren(af, id,
distinctAggChildAttrMap)
+ (e, e.copy(aggregateFunction = naf, isDistinct = false))
+ }
+
+ (projection, operators)
+ }
+
+ // Setup expand for the 'regular' aggregate expressions.
+ val regularAggExprs = aggExpressions.filter(!_.isDistinct)
+ val regularAggChildren =
regularAggExprs.flatMap(_.aggregateFunction.children).distinct
+ val regularAggChildAttrMap =
regularAggChildren.map(expressionAttributePair).toMap
+
+ // Setup aggregates for 'regular' aggregate expressions.
+ val regularGroupId = Literal(0)
+ val regularAggOperatorMap = regularAggExprs.map { e =>
+ // Perform the actual aggregation in the initial aggregate.
+ val af = patchAggregateFunctionChildren(
+ e.aggregateFunction,
+ regularGroupId,
+ regularAggChildAttrMap)
+ val a = Alias(e.copy(aggregateFunction = af), e.toString)()
+
+ // Get the result of the first aggregate in the last aggregate.
+ val b = AggregateExpression2(
+ aggregate.First(evalWithinGroup(regularGroupId, a.toAttribute),
Literal(true)),
+ mode = Complete,
+ isDistinct = false)
+
+ // Some aggregate functions (COUNT) have the special property that
they can return a
+ // non-null result without any input. We need to make sure we
return a result in this case.
+ val c = af.defaultResult match {
+ case Some(lit) => Coalesce(Seq(b, lit))
+ case None => b
+ }
+
+ (e, a, c)
+ }
+
+ // Construct the regular aggregate input projection only if we need
one.
+ val regularAggProjection = if (regularAggExprs.nonEmpty) {
+ Seq(a.groupingExpressions ++
+ distinctAggChildren.map(nullify) ++
+ Seq(regularGroupId) ++
+ regularAggChildren)
+ } else {
+ Seq.empty[Seq[Expression]]
+ }
+
+ // Construct the distinct aggregate input projections.
+ val regularAggNulls = regularAggChildren.map(nullify)
+ val distinctAggProjections = distinctAggOperatorMap.map {
+ case (projection, _) =>
+ a.groupingExpressions ++
+ projection ++
+ regularAggNulls
+ }
+
+ // Construct the expand operator.
+ val expand = Expand(
+ regularAggProjection ++ distinctAggProjections,
+ groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++
regularAggChildAttrMap.values.toSeq,
+ a.child)
+
+ // Construct the first aggregate operator. This de-duplicates
all the children of
+ // distinct operators, and applies the regular aggregate operators.
+ val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+
gid
+ val firstAggregate = Aggregate(
+ firstAggregateGroupBy,
+ firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2),
+ expand)
+
+ // Construct the second aggregate
+ val transformations: Map[Expression, Expression] =
+ (distinctAggOperatorMap.flatMap(_._2) ++
+ regularAggOperatorMap.map(e => (e._1, e._3))).toMap
+
+ val patchedAggExpressions = a.aggregateExpressions.map { e =>
+ e.transformDown {
+ case e: Expression =>
+ // The same GROUP BY clauses can have different forms
(different names for instance) in
+ // the groupBy and aggregate expressions of an aggregate. This
makes a map lookup
+ // tricky. So we do a linear search for a semantically equal
group by expression.
--- End diff --
We've talked about adding an `ExpressionMap` similar to `AttributeMap` in
the past. Seems like that would be useful here.
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]