GitHub user hvanhovell commented on a diff in the pull request:

    https://github.com/apache/spark/pull/9406#discussion_r44203639
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
 ---
    @@ -213,3 +216,178 @@ object Utils {
         case other => None
       }
     }
    +
    +/**
    + * This rule rewrites an aggregate query with multiple distinct clauses into an expanded double
    + * aggregation in which the regular aggregation expressions and every distinct clause are aggregated
    + * in a separate group. The results are then combined in a second aggregate.
    + *
    + * TODO Expression canonicalization
    + * TODO Eliminate foldable expressions from distinct clauses.
    + * TODO This eliminates all distinct expressions. We could safely pass one to the aggregate
    + *      operator. Perhaps this is a good thing? It is much simpler to plan later on...
    + */
    +object MultipleDistinctRewriter extends Rule[LogicalPlan] {
    +
    +  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
    +    case a: Aggregate => rewrite(a)
    +    case p => p
    +  }
    +
    +  def rewrite(a: Aggregate): Aggregate = {
    +
    +    // Collect all aggregate expressions.
    +    val aggExpressions = a.aggregateExpressions.flatMap { e =>
    +      e.collect {
    +        case ae: AggregateExpression2 => ae
    +      }
    +    }
    +
    +    // Extract distinct aggregate expressions.
    +    val distinctAggGroups = aggExpressions
    +      .filter(_.isDistinct)
    +      .groupBy(_.aggregateFunction.children.toSet)
    +
    +    // Only continue to rewrite if there is more than one distinct group.
    +    if (distinctAggGroups.size > 1) {
    +      // Create the attributes for the grouping id and the group by clause.
    +      val gid = new AttributeReference("gid", IntegerType, false)()
    +      val groupByMap = a.groupingExpressions.collect {
    +        case ne: NamedExpression => ne -> ne.toAttribute
    +        case e => e -> new AttributeReference(e.prettyName, e.dataType, 
e.nullable)()
    +      }
    +      val groupByAttrs = groupByMap.map(_._2)
    +
    +      // Functions used to modify aggregate functions and their inputs.
    +      def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, 
id), e, nullify(e))
    +      def patchAggregateFunctionChildren(
    +          af: AggregateFunction2,
    +          id: Literal,
    +          attrs: Map[Expression, Expression]): AggregateFunction2 = {
    +        af.withNewChildren(af.children.map { case afc =>
    +          evalWithinGroup(id, attrs(afc))
    +        }).asInstanceOf[AggregateFunction2]
    +      }
    +
    +      // Setup unique distinct aggregate children.
    +      val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq
    +      val distinctAggChildAttrMap = 
distinctAggChildren.map(expressionAttributePair).toMap
    +      val distinctAggChildAttrs = distinctAggChildAttrMap.values.toSeq
    +
    +      // Setup expand & aggregate operators for distinct aggregate 
expressions.
    +      val distinctAggOperatorMap = 
distinctAggGroups.toSeq.zipWithIndex.map {
    +        case ((group, expressions), i) =>
    +          val id = Literal(i + 1)
    +
    +          // Expand projection
    +          val projection = distinctAggChildren.map {
    +            case e if group.contains(e) => e
    +            case e => nullify(e)
    +          } :+ id
    +
    +          // Final aggregate
    +          val operators = expressions.map { e =>
    +            val af = e.aggregateFunction
    +            val naf = patchAggregateFunctionChildren(af, id, 
distinctAggChildAttrMap)
    +            (e, e.copy(aggregateFunction = naf, isDistinct = false))
    +          }
    +
    +          (projection, operators)
    +      }
    +
    +      // Setup expand for the 'regular' aggregate expressions.
    +      val regularAggExprs = aggExpressions.filter(!_.isDistinct)
    +      val regularAggChildren = 
regularAggExprs.flatMap(_.aggregateFunction.children).distinct
    +      val regularAggChildAttrMap = 
regularAggChildren.map(expressionAttributePair).toMap
    +
    +      // Setup aggregates for 'regular' aggregate expressions.
    +      val regularGroupId = Literal(0)
    +      val regularAggOperatorMap = regularAggExprs.map { e =>
    --- End diff --
    
    I'll add documentation in a follow-up PR.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it enabled, or if the feature is enabled but not working,
please contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to