Github user marmbrus commented on a diff in the pull request:

    https://github.com/apache/spark/pull/9406#discussion_r44199453
  
    --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala ---
    @@ -213,3 +216,178 @@ object Utils {
         case other => None
       }
     }
    +
    +/**
    + * This rule rewrites an aggregate query with multiple distinct clauses into an expanded double
    + * aggregation in which the regular aggregation expressions and every distinct clause is aggregated
    + * in a separate group. The results are then combined in a second aggregate.
    + *
    + * TODO Expression canonicalization
    + * TODO Eliminate foldable expressions from distinct clauses.
    + * TODO This eliminates all distinct expressions. We could safely pass one to the aggregate
    + *      operator. Perhaps this is a good thing? It is much simpler to plan later on...
    + */
    +object MultipleDistinctRewriter extends Rule[LogicalPlan] {
    +
    +  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
    +    case a: Aggregate => rewrite(a)
    +    case p => p
    +  }
    +
    +  def rewrite(a: Aggregate): Aggregate = {
    +
    +    // Collect all aggregate expressions.
    +    val aggExpressions = a.aggregateExpressions.flatMap { e =>
    +      e.collect {
    +        case ae: AggregateExpression2 => ae
    +      }
    +    }
    +
    +    // Extract distinct aggregate expressions.
    +    val distinctAggGroups = aggExpressions
    +      .filter(_.isDistinct)
    +      .groupBy(_.aggregateFunction.children.toSet)
    +
    +    // Only continue to rewrite if there is more than one distinct group.
    +    if (distinctAggGroups.size > 1) {
    +      // Create the attributes for the grouping id and the group by clause.
    +      val gid = new AttributeReference("gid", IntegerType, false)()
    +      val groupByMap = a.groupingExpressions.collect {
    +        case ne: NamedExpression => ne -> ne.toAttribute
    +        case e => e -> new AttributeReference(e.prettyName, e.dataType, e.nullable)()
    +      }
    +      val groupByAttrs = groupByMap.map(_._2)
    +
    +      // Functions used to modify aggregate functions and their inputs.
    +      def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e))
    +      def patchAggregateFunctionChildren(
    +          af: AggregateFunction2,
    +          id: Literal,
    +          attrs: Map[Expression, Expression]): AggregateFunction2 = {
    +        af.withNewChildren(af.children.map { case afc =>
    +          evalWithinGroup(id, attrs(afc))
    +        }).asInstanceOf[AggregateFunction2]
    +      }
    +
    +      // Setup unique distinct aggregate children.
    +      val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq
    +      val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair).toMap
    +      val distinctAggChildAttrs = distinctAggChildAttrMap.values.toSeq
    +
    +      // Setup expand & aggregate operators for distinct aggregate 
expressions.
    +      val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map {
    +        case ((group, expressions), i) =>
    +          val id = Literal(i + 1)
    +
    +          // Expand projection
    +          val projection = distinctAggChildren.map {
    +            case e if group.contains(e) => e
    +            case e => nullify(e)
    +          } :+ id
    +
    +          // Final aggregate
    +          val operators = expressions.map { e =>
    +            val af = e.aggregateFunction
    +            val naf = patchAggregateFunctionChildren(af, id, distinctAggChildAttrMap)
    +            (e, e.copy(aggregateFunction = naf, isDistinct = false))
    +          }
    +
    +          (projection, operators)
    +      }
    +
    +      // Setup expand for the 'regular' aggregate expressions.
    +      val regularAggExprs = aggExpressions.filter(!_.isDistinct)
    +      val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct
    +      val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair).toMap
    +
    +      // Setup aggregates for 'regular' aggregate expressions.
    +      val regularGroupId = Literal(0)
    +      val regularAggOperatorMap = regularAggExprs.map { e =>
    +        // Perform the actual aggregation in the initial aggregate.
    +        val af = patchAggregateFunctionChildren(
    +          e.aggregateFunction,
    +          regularGroupId,
    +          regularAggChildAttrMap)
    +        val a = Alias(e.copy(aggregateFunction = af), e.toString)()
    +
    +        // Get the result of the first aggregate in the last aggregate.
    +        val b = AggregateExpression2(
    +          aggregate.First(evalWithinGroup(regularGroupId, a.toAttribute), Literal(true)),
    +          mode = Complete,
    +          isDistinct = false)
    +
    +        // Some aggregate functions (COUNT) have the special property that they can return a
    +        // non-null result without any input. We need to make sure we return a result in this case.
    +        val c = af.defaultResult match {
    +          case Some(lit) => Coalesce(Seq(b, lit))
    +          case None => b
    +        }
    +
    +        (e, a, c)
    +      }
    +
    +      // Construct the regular aggregate input projection only if we need one.
    +      val regularAggProjection = if (regularAggExprs.nonEmpty) {
    +        Seq(a.groupingExpressions ++
    +          distinctAggChildren.map(nullify) ++
    +          Seq(regularGroupId) ++
    +          regularAggChildren)
    +      } else {
    +        Seq.empty[Seq[Expression]]
    +      }
    +
    +      // Construct the distinct aggregate input projections.
    +      val regularAggNulls = regularAggChildren.map(nullify)
    +      val distinctAggProjections = distinctAggOperatorMap.map {
    +        case (projection, _) =>
    +          a.groupingExpressions ++
    +            projection ++
    +            regularAggNulls
    +      }
    +
    +      // Construct the expand operator.
    +      val expand = Expand(
    +        regularAggProjection ++ distinctAggProjections,
    +        groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.values.toSeq,
    +        a.child)
    +
    +      // Construct the first aggregate operator. This de-duplicates all the children of
    +      // distinct operators, and applies the regular aggregate operators.
    +      val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid
    +      val firstAggregate = Aggregate(
    +        firstAggregateGroupBy,
    +        firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2),
    +        expand)
    +
    +      // Construct the second aggregate
    +      val transformations: Map[Expression, Expression] =
    +        (distinctAggOperatorMap.flatMap(_._2) ++
    +          regularAggOperatorMap.map(e => (e._1, e._3))).toMap
    +
    +      val patchedAggExpressions = a.aggregateExpressions.map { e =>
    +        e.transformDown {
    +          case e: Expression =>
    +            // The same GROUP BY clauses can have different forms (different names for instance) in
    +            // the groupBy and aggregate expressions of an aggregate. This makes a map lookup
    +            // tricky. So we do a linear search for a semantically equal group by expression.
    --- End diff --
    
    We've talked about adding an `ExpressionMap` similar to `AttributeMap` in 
the past.  Seems like that would be useful here.
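
    A rough sketch of what that could look like, assuming `Expression#semanticEquals` is the notion of equality we want here (the class name and API below are just illustrative, not an existing helper):

    ```scala
    import org.apache.spark.sql.catalyst.expressions.Expression

    // Sketch only: a map-like lookup keyed on semantic equality rather than
    // object equality. A real version would probably canonicalize keys up
    // front instead of doing a linear scan on every lookup.
    class ExpressionMap[A](kvs: Seq[(Expression, A)]) {
      def get(e: Expression): Option[A] =
        kvs.collectFirst { case (k, v) if k.semanticEquals(e) => v }

      def apply(e: Expression): A =
        get(e).getOrElse(sys.error(s"No entry for expression: $e"))

      def getOrElse(e: Expression, default: => A): A = get(e).getOrElse(default)
    }

    object ExpressionMap {
      def apply[A](kvs: Seq[(Expression, A)]): ExpressionMap[A] = new ExpressionMap(kvs)
    }
    ```

    That would keep the semantic-equality search in one place instead of re-implementing it at each call site.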

