maropu commented on a change in pull request #27058: [SPARK-30276][SQL] Support Filter expression allows simultaneous use of DISTINCT URL: https://github.com/apache/spark/pull/27058#discussion_r368808060
########## File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala ########## @@ -148,24 +207,106 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { val distinctAggs = exprs.flatMap { _.collect { case ae: AggregateExpression if ae.isDistinct => ae }} - // We need at least two distinct aggregates for this rule because aggregation - // strategy can handle a single distinct group. + // This rule serves two purposes: + // The first is to rewrite when there exist at least two distinct aggregates. We need at least + // two distinct aggregates for this rule because the aggregation strategy can handle a single + // distinct group. + // The second is to expand distinct aggregates that have a filter clause so that the filter + // can be evaluated locally. // This check can produce false-positives, e.g., SUM(DISTINCT a) & COUNT(DISTINCT a). - distinctAggs.size > 1 + distinctAggs.size >= 1 || distinctAggs.exists(_.filter.isDefined) } def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case a: Aggregate if mayNeedtoRewrite(a.aggregateExpressions) => rewrite(a) + case a: Aggregate if mayNeedtoRewrite(a.aggregateExpressions) => + val expandAggregate = extractFiltersInDistinctAggregate(a) + rewriteDistinctAggregate(expandAggregate) } - def rewrite(a: Aggregate): Aggregate = { + private def extractFiltersInDistinctAggregate(a: Aggregate): Aggregate = { + val aggExpressions = collectAggregateExprs(a) + val (distinctAggExpressions, regularAggExpressions) = aggExpressions.partition(_.isDistinct) + if (distinctAggExpressions.exists(_.filter.isDefined)) { + // Set up expand for the 'regular' aggregate expressions. Because we will construct a new + // aggregate, the children of the distinct aggregates will be changed to the generated + // ones, so we need to create new references to avoid collisions between distinct and + // regular aggregate children.
+ val regularAggExprs = regularAggExpressions.filter(_.children.exists(!_.foldable)) + val regularFunChildren = regularAggExprs + .flatMap(_.aggregateFunction.children.filter(!_.foldable)) + val regularFilterAttrs = regularAggExprs.flatMap(_.filterAttributes) + val regularAggChildren = (regularFunChildren ++ regularFilterAttrs).distinct + val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair) + val regularAggChildAttrLookup = regularAggChildAttrMap.toMap + val regularAggMap = regularAggExprs.map { + case ae @ AggregateExpression(af, _, _, filter, _) => + val newChildren = af.children.map(c => regularAggChildAttrLookup.getOrElse(c, c)) + val raf = af.withNewChildren(newChildren).asInstanceOf[AggregateFunction] + val filterOpt = filter.map(_.transform { + case a: Attribute => regularAggChildAttrLookup.getOrElse(a, a) + }) + val aggExpr = ae.copy(aggregateFunction = raf, filter = filterOpt) + (ae, aggExpr) + } - // Collect all aggregate expressions. - val aggExpressions = a.aggregateExpressions.flatMap { e => - e.collect { - case ae: AggregateExpression => ae + // Set up expand for the distinct aggregate expressions. + val distinctAggExprs = distinctAggExpressions.filter(e => e.children.exists(!_.foldable)) + val (projections, expressionAttrs, aggExprPairs) = distinctAggExprs.map { + case ae @ AggregateExpression(af, _, _, filter, _) => + // Why do we need to construct the `exprId`? + // First, to reduce costs, it is better to evaluate the filter clause locally, + // e.g. for COUNT (DISTINCT a) FILTER (WHERE id > 1), first evaluate the expression + // If(id > 1) 'a else null, and use the result as the output. + // Second, two or more DISTINCT aggregate expressions may reference the + // same attributes, so we need to construct generated attributes to avoid losing + // output, e.g. SUM (DISTINCT a), COUNT (DISTINCT a) FILTER (WHERE id > 1) will output + // attributes '_gen_distinct_1 and '_gen_distinct_2 instead of two 'a.
+ // Note: We just need to illusion the expression with filter clause. + // The illusionary mechanism may result in multiple distinct aggregations uses + // different column, so we still need to call `rewrite`. + val exprId = NamedExpression.newExprId.id + val unfoldableChildren = af.children.filter(!_.foldable) + val exprAttrs = unfoldableChildren.map { e => + (e, AttributeReference(s"_gen_distinct_$exprId", e.dataType, nullable = true)()) + } + val exprAttrLookup = exprAttrs.toMap + val newChildren = af.children.map(c => exprAttrLookup.getOrElse(c, c)) + val raf = af.withNewChildren(newChildren).asInstanceOf[AggregateFunction] + val aggExpr = ae.copy(aggregateFunction = raf, filter = None) + // Expand projection + val projection = unfoldableChildren.map { + case e if filter.isDefined => If(filter.get, e, nullify(e)) + case e => e + } + (projection, exprAttrs, (ae, aggExpr)) + }.unzip3 + val distinctAggChildAttrs = expressionAttrs.flatten.map(_._2) + val allAggAttrs = regularAggChildAttrMap.map(_._2) ++ distinctAggChildAttrs + // Construct the aggregate input projection. + val rewriteAggProjections = + Seq(a.groupingExpressions ++ regularAggChildren ++ projections.flatten) + val groupByMap = a.groupingExpressions.collect { + case ne: NamedExpression => ne -> ne.toAttribute + case e => e -> AttributeReference(e.sql, e.dataType, e.nullable)() } + val groupByAttrs = groupByMap.map(_._2) + // Construct the expand operator. 
+ val expand = Expand(rewriteAggProjections, groupByAttrs ++ allAggAttrs, a.child) + val rewriteAggExprLookup = (aggExprPairs ++ regularAggMap).toMap + val patchedAggExpressions = a.aggregateExpressions.map { e => + e.transformDown { + case ae: AggregateExpression => rewriteAggExprLookup.getOrElse(ae, ae) + }.asInstanceOf[NamedExpression] + } + val expandAggregate = Aggregate(groupByAttrs, patchedAggExpressions, expand) + expandAggregate + } else { + a } + } + + private def rewriteDistinctAggregate(a: Aggregate): Aggregate = { Review comment: nit: `rewriteDistinctAggregate`->`rewriteDistinctAggregates` ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org