maropu commented on a change in pull request #27058: [SPARK-30276][SQL] Support Filter expression allows simultaneous use of DISTINCT
URL: https://github.com/apache/spark/pull/27058#discussion_r368808060
 
 

 ##########
 File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
 ##########
 @@ -148,24 +207,106 @@ object RewriteDistinctAggregates extends 
Rule[LogicalPlan] {
     val distinctAggs = exprs.flatMap { _.collect {
       case ae: AggregateExpression if ae.isDistinct => ae
     }}
-    // We need at least two distinct aggregates for this rule because 
aggregation
-    // strategy can handle a single distinct group.
+    // This rule serves two purposes:
+    // One is to rewrite when there exists at least two distinct aggregates. 
We need at least
+    // two distinct aggregates for this rule because aggregation strategy can 
handle a single
+    // distinct group.
+    // Another is to expand distinct aggregates which exists filter clause so 
that we can
+    // evaluate the filter locally.
     // This check can produce false-positives, e.g., SUM(DISTINCT a) & 
COUNT(DISTINCT a).
-    distinctAggs.size > 1
+    distinctAggs.size >= 1 || distinctAggs.exists(_.filter.isDefined)
   }
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
-    case a: Aggregate if mayNeedtoRewrite(a.aggregateExpressions) => rewrite(a)
+    case a: Aggregate if mayNeedtoRewrite(a.aggregateExpressions) =>
+      // Two passes over each matching Aggregate (visited bottom-up via transformUp):
+      // 1. move the FILTER clauses of DISTINCT aggregates into an Expand projection
+      //    (returns `a` itself when no DISTINCT aggregate has a filter);
+      // 2. run the distinct-aggregate rewrite on the result.
+      // NOTE(review): mayNeedtoRewrite now fires on `distinctAggs.size >= 1`, i.e. also for a
+      // single DISTINCT aggregate without FILTER, although its comment says two are needed --
+      // confirm rewriteDistinctAggregate short-circuits that case.
+      val expandAggregate = extractFiltersInDistinctAggregate(a)
+      rewriteDistinctAggregate(expandAggregate)
   }
 
-  def rewrite(a: Aggregate): Aggregate = {
+  /**
+   * Moves the FILTER clauses of DISTINCT aggregates in `a` into an [[Expand]] placed below a
+   * new [[Aggregate]], so each filter is evaluated in the projection rather than by the
+   * aggregate itself.
+   *
+   * Returns `a` unchanged when no DISTINCT aggregate has a filter. Otherwise:
+   *  - grouping expressions are projected first; named ones keep their attribute, others get a
+   *    fresh attribute named after their SQL text;
+   *  - regular (non-DISTINCT) aggregates for which `children.exists(!_.foldable)` holds have
+   *    their non-foldable function children and filter attributes projected and remapped;
+   *  - DISTINCT aggregates for which `children.exists(!_.foldable)` holds get their non-foldable
+   *    children replaced by `_gen_distinct_<id>` attributes; when a filter is present the
+   *    projected value is `If(filter, child, nullify(child))` and the filter is dropped.
+   * The Expand has a single projection. `apply` then passes the result to
+   * `rewriteDistinctAggregate`.
+   *
+   * NOTE(review): replacing rows that fail the filter with `nullify(child)` (presumably a
+   * typed null) is only equivalent if the aggregate function ignores null inputs -- confirm
+   * for every function that may be combined with DISTINCT and FILTER.
+   *
+   * @param a the aggregate to inspect
+   * @return `a` itself, or a new Aggregate over an Expand of `a.child`
+   */
+  private def extractFiltersInDistinctAggregate(a: Aggregate): Aggregate = {
+    val aggExpressions = collectAggregateExprs(a)
+    val (distinctAggExpressions, regularAggExpressions) = 
aggExpressions.partition(_.isDistinct)
+    if (distinctAggExpressions.exists(_.filter.isDefined)) {
+      // Set up the Expand for the 'regular' aggregate expressions. Because we will 
construct a new
+      // aggregate, the children of the distinct aggregates will be changed to 
the generate
+      // ones, so we need to create new references to avoid collisions between 
distinct and
+      // regular aggregate children.
+      val regularAggExprs = 
regularAggExpressions.filter(_.children.exists(!_.foldable))
+      val regularFunChildren = regularAggExprs
+        .flatMap(_.aggregateFunction.children.filter(!_.foldable))
+      val regularFilterAttrs = regularAggExprs.flatMap(_.filterAttributes)
+      val regularAggChildren = (regularFunChildren ++ 
regularFilterAttrs).distinct
+      val regularAggChildAttrMap = 
regularAggChildren.map(expressionAttributePair)
+      val regularAggChildAttrLookup = regularAggChildAttrMap.toMap
+      val regularAggMap = regularAggExprs.map {
+        case ae @ AggregateExpression(af, _, _, filter, _) =>
+          val newChildren = af.children.map(c => 
regularAggChildAttrLookup.getOrElse(c, c))
+          val raf = 
af.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
+          val filterOpt = filter.map(_.transform {
+            case a: Attribute => regularAggChildAttrLookup.getOrElse(a, a)
+          })
+          val aggExpr = ae.copy(aggregateFunction = raf, filter = filterOpt)
+          (ae, aggExpr)
+      }
 
-    // Collect all aggregate expressions.
-    val aggExpressions = a.aggregateExpressions.flatMap { e =>
-      e.collect {
-        case ae: AggregateExpression => ae
+      // Set up the Expand for the distinct aggregate expressions.
+      val distinctAggExprs = distinctAggExpressions.filter(e => 
e.children.exists(!_.foldable))
+      val (projections, expressionAttrs, aggExprPairs) = distinctAggExprs.map {
+        case ae @ AggregateExpression(af, _, _, filter, _) =>
+          // Why do we need to construct the `exprId`?
+          // First, to reduce costs, it is better to handle the 
filter clause locally.
+          // e.g. COUNT (DISTINCT a) FILTER (WHERE id > 1), evaluate expression
+          // If(id > 1) 'a else null first, and use the result as output.
+          // Second, If at least two DISTINCT aggregate expression which may 
references the
+          // same attributes. We need to construct the generated attributes so 
as the output not
+          // lost. e.g. SUM (DISTINCT a), COUNT (DISTINCT a) FILTER (WHERE id 
> 1) will output
+          // attribute '_gen_distinct_1 and attribute '_gen_distinct_2 instead 
of two 'a.
+          // Note: the `If` wrapping below only applies to expressions with a filter clause.
+          // The illusionary mechanism may result in multiple distinct 
aggregations uses
+          // different column, so we still need to call `rewriteDistinctAggregate`.
+          val exprId = NamedExpression.newExprId.id
+          val unfoldableChildren = af.children.filter(!_.foldable)
+          // NOTE(review): if every child of `af` is foldable, nothing is projected for this
+          // aggregate while `filter = None` below drops its FILTER -- confirm that case is
+          // excluded or handled elsewhere.
+          val exprAttrs = unfoldableChildren.map { e =>
+            // All children of one aggregate share the name `_gen_distinct_<exprId>`; they are
+            // told apart by the fresh exprId of each AttributeReference.
+            (e, AttributeReference(s"_gen_distinct_$exprId", e.dataType, 
nullable = true)())
+          }
+          val exprAttrLookup = exprAttrs.toMap
+          val newChildren = af.children.map(c => exprAttrLookup.getOrElse(c, 
c))
+          val raf = 
af.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
+          val aggExpr = ae.copy(aggregateFunction = raf, filter = None)
+          // Expand projection: with a filter, each child becomes If(filter, child, nullify(child)).
+          val projection = unfoldableChildren.map {
+            case e if filter.isDefined => If(filter.get, e, nullify(e))
+            case e => e
+          }
+          (projection, exprAttrs, (ae, aggExpr))
+      }.unzip3
+      val distinctAggChildAttrs = expressionAttrs.flatten.map(_._2)
+      val allAggAttrs = regularAggChildAttrMap.map(_._2) ++ 
distinctAggChildAttrs
+      // Construct the aggregate input projection (a single projection for the Expand).
+      val rewriteAggProjections =
+        Seq(a.groupingExpressions ++ regularAggChildren ++ projections.flatten)
+      // Named grouping expressions keep their attribute; others get a fresh attribute named
+      // after their SQL text.
+      val groupByMap = a.groupingExpressions.collect {
+        case ne: NamedExpression => ne -> ne.toAttribute
+        case e => e -> AttributeReference(e.sql, e.dataType, e.nullable)()
       }
+      val groupByAttrs = groupByMap.map(_._2)
+      // Construct the expand operator.
+      val expand = Expand(rewriteAggProjections, groupByAttrs ++ allAggAttrs, 
a.child)
+      val rewriteAggExprLookup = (aggExprPairs ++ regularAggMap).toMap
+      // NOTE(review): only AggregateExpressions are remapped here; result expressions that
+      // reference a non-named grouping expression are not rewritten to the fresh attribute
+      // created for it in groupByMap -- confirm this cannot leave an unresolved reference.
+      val patchedAggExpressions = a.aggregateExpressions.map { e =>
+        e.transformDown {
+          case ae: AggregateExpression => rewriteAggExprLookup.getOrElse(ae, 
ae)
+        }.asInstanceOf[NamedExpression]
+      }
+      val expandAggregate = Aggregate(groupByAttrs, patchedAggExpressions, 
expand)
+      expandAggregate
+    } else {
+      a
     }
+  }
+
+  private def rewriteDistinctAggregate(a: Aggregate): Aggregate = {
 
 Review comment:
   nit: rename `rewriteDistinctAggregate` -> `rewriteDistinctAggregates`, to match the rule name `RewriteDistinctAggregates`.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to