yjshen commented on issue #5547:
URL: https://github.com/apache/arrow-datafusion/issues/5547#issuecomment-1464255278

   I was suggesting using an optimizer rule to rewrite aggregates with DISTINCT into a double aggregation, so that no distinct `AggregateExpr`s remain at execution time.
   
   The gist of the idea is to first add the distinct columns as extra grouping columns and compute (partial) results for the non-distinct aggregates. A second round of aggregation then computes the distinct expressions. By that point their values have already been deduplicated, because they served as grouping columns in the first aggregation.
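
For the simplest case, a single distinct column next to a regular aggregate, the two phases can be sketched in plain Python. This is only an illustration of the idea above, not DataFusion or Spark code. The rows, the `direct`/`double_aggregation` helpers and the assumption that `value` is never NULL are all hypothetical.

```python
from collections import defaultdict

# Hypothetical rows (key, cat1, value); `value` is assumed non-NULL.
ROWS = [
    ("a", "ca1", 10),
    ("a", "ca1", 5),
    ("a", "ca2", 1),
    ("b", "ca1", 13),
]


def direct(rows):
    """Reference: SELECT key, COUNT(DISTINCT cat1), SUM(value) FROM t GROUP BY key."""
    distinct = defaultdict(set)
    total = defaultdict(int)
    for key, cat1, value in rows:
        if cat1 is not None:  # COUNT skips NULLs
            distinct[key].add(cat1)
        total[key] += value
    return {k: (len(distinct[k]), total[k]) for k in total}


def double_aggregation(rows):
    """Rewrite: cat1 becomes an extra grouping column of the first aggregation."""
    # Phase 1: GROUP BY key, cat1 -> partial SUM(value). Duplicate cat1 values
    # within a key collapse into a single group here.
    partial = defaultdict(int)
    for key, cat1, value in rows:
        partial[(key, cat1)] += value
    # Phase 2: GROUP BY key -> plain COUNT(cat1), which is now distinct because
    # each (key, cat1) pair appears once, plus SUM of the partial sums.
    result = {}
    for (key, cat1), partial_sum in partial.items():
        cnt, total = result.get(key, (0, 0))
        result[key] = (cnt + int(cat1 is not None), total + partial_sum)
    return result
```

Both functions return `{key: (distinct_count, total)}` and agree on the sample rows. The non-distinct `SUM` survives the rewrite only because it can be split into partial sums and re-aggregated.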
   
   Below is the Scaladoc for Spark's `RewriteDistinctAggregates` rule, since it contains helpful examples ([source here](https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala)).
   
   ```scala
   /**
    * This rule rewrites an aggregate query with distinct aggregations into an expanded double
    * aggregation in which the regular aggregation expressions and every distinct clause is aggregated
    * in a separate group. The results are then combined in a second aggregate.
    *
    * First example: query without filter clauses (in scala):
    * {{{
    *   val data = Seq(
    *     ("a", "ca1", "cb1", 10),
    *     ("a", "ca1", "cb2", 5),
    *     ("b", "ca1", "cb1", 13))
    *     .toDF("key", "cat1", "cat2", "value")
    *   data.createOrReplaceTempView("data")
    *
    *   val agg = data.groupBy($"key")
    *     .agg(
    *       count_distinct($"cat1").as("cat1_cnt"),
    *       count_distinct($"cat2").as("cat2_cnt"),
    *       sum($"value").as("total"))
    * }}}
    *
    * This translates to the following (pseudo) logical plan:
    * {{{
    * Aggregate(
    *    key = ['key]
    *    functions = [COUNT(DISTINCT 'cat1),
    *                 COUNT(DISTINCT 'cat2),
    *                 sum('value)]
    *    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
    *   LocalTableScan [...]
    * }}}
    *
    * This rule rewrites this logical plan to the following (pseudo) logical plan:
    * {{{
    * Aggregate(
    *    key = ['key]
    *    functions = [count('cat1) FILTER (WHERE 'gid = 1),
    *                 count('cat2) FILTER (WHERE 'gid = 2),
    *                 first('total) ignore nulls FILTER (WHERE 'gid = 0)]
    *    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
    *   Aggregate(
    *      key = ['key, 'cat1, 'cat2, 'gid]
    *      functions = [sum('value)]
    *      output = ['key, 'cat1, 'cat2, 'gid, 'total])
    *     Expand(
    *        projections = [('key, null, null, 0, cast('value as bigint)),
    *                       ('key, 'cat1, null, 1, null),
    *                       ('key, null, 'cat2, 2, null)]
    *        output = ['key, 'cat1, 'cat2, 'gid, 'value])
    *       LocalTableScan [...]
    * }}}
    *
    * Second example: aggregate function without distinct and with filter clauses (in sql):
    * {{{
    *   SELECT
    *     COUNT(DISTINCT cat1) as cat1_cnt,
    *     COUNT(DISTINCT cat2) as cat2_cnt,
    *     SUM(value) FILTER (WHERE id > 1) AS total
    *   FROM
    *     data
    *   GROUP BY
    *     key
    * }}}
    *
    * This translates to the following (pseudo) logical plan:
    * {{{
    * Aggregate(
    *    key = ['key]
    *    functions = [COUNT(DISTINCT 'cat1),
    *                 COUNT(DISTINCT 'cat2),
    *                 sum('value) FILTER (WHERE 'id > 1)]
    *    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
    *   LocalTableScan [...]
    * }}}
    *
    * This rule rewrites this logical plan to the following (pseudo) logical plan:
    * {{{
    * Aggregate(
    *    key = ['key]
    *    functions = [count('cat1) FILTER (WHERE 'gid = 1),
    *                 count('cat2) FILTER (WHERE 'gid = 2),
    *                 first('total) ignore nulls FILTER (WHERE 'gid = 0)]
    *    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
    *   Aggregate(
    *      key = ['key, 'cat1, 'cat2, 'gid]
    *      functions = [sum('value) FILTER (WHERE 'id > 1)]
    *      output = ['key, 'cat1, 'cat2, 'gid, 'total])
    *     Expand(
    *        projections = [('key, null, null, 0, cast('value as bigint), 'id),
    *                       ('key, 'cat1, null, 1, null, null),
    *                       ('key, null, 'cat2, 2, null, null)]
    *        output = ['key, 'cat1, 'cat2, 'gid, 'value, 'id])
    *       LocalTableScan [...]
    * }}}
    *
    * Third example: aggregate function with distinct and filter clauses (in sql):
    * {{{
    *   SELECT
    *     COUNT(DISTINCT cat1) FILTER (WHERE id > 1) as cat1_cnt,
    *     COUNT(DISTINCT cat2) FILTER (WHERE id > 2) as cat2_cnt,
    *     SUM(value) FILTER (WHERE id > 3) AS total
    *   FROM
    *     data
    *   GROUP BY
    *     key
    * }}}
    *
    * This translates to the following (pseudo) logical plan:
    * {{{
    * Aggregate(
    *    key = ['key]
    *    functions = [COUNT(DISTINCT 'cat1) FILTER (WHERE 'id > 1),
    *                 COUNT(DISTINCT 'cat2) FILTER (WHERE 'id > 2),
    *                 sum('value) FILTER (WHERE 'id > 3)]
    *    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
    *   LocalTableScan [...]
    * }}}
    *
    * This rule rewrites this logical plan to the following (pseudo) logical plan:
    * {{{
    * Aggregate(
    *    key = ['key]
    *    functions = [count('cat1) FILTER (WHERE 'gid = 1 and 'max_cond1),
    *                 count('cat2) FILTER (WHERE 'gid = 2 and 'max_cond2),
    *                 first('total) ignore nulls FILTER (WHERE 'gid = 0)]
    *    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
    *   Aggregate(
    *      key = ['key, 'cat1, 'cat2, 'gid]
    *      functions = [max('cond1), max('cond2), sum('value) FILTER (WHERE 'id > 3)]
    *      output = ['key, 'cat1, 'cat2, 'gid, 'max_cond1, 'max_cond2, 'total])
    *     Expand(
    *        projections = [('key, null, null, 0, null, null, cast('value as bigint), 'id),
    *                       ('key, 'cat1, null, 1, 'id > 1, null, null, null),
    *                       ('key, null, 'cat2, 2, null, 'id > 2, null, null)]
    *        output = ['key, 'cat1, 'cat2, 'gid, 'cond1, 'cond2, 'value, 'id])
    *       LocalTableScan [...]
    * }}}
   ```
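
The third example is the most general, so a small Python simulation can check it end to end. It evaluates the original query and the rewritten plan (Expand, then two Aggregates) over in-memory tuples and compares the results. This is only an illustration of the Spark rule quoted above, not Spark or DataFusion code. The rows, the `reference`/`rewritten`/`sql_max` helpers and the tuple layout are hypothetical. NULL is modeled as `None`, and `value`/`id` are assumed non-NULL in the input.

```python
from collections import defaultdict

# Hypothetical input rows: (key, cat1, cat2, value, id).
ROWS = [
    ("a", "ca1", "cb1", 10, 1),
    ("a", "ca1", "cb2", 5, 2),
    ("a", "ca2", "cb2", 7, 3),
    ("a", "ca1", "cb1", 4, 4),
    ("a", "ca3", "cb3", 1, 1),  # fails both DISTINCT filters
    ("b", "ca1", "cb1", 13, 5),
    ("c", "ca1", "cb1", 2, 1),  # fails every filter
]


def reference(rows):
    """SELECT key, COUNT(DISTINCT cat1) FILTER (WHERE id > 1),
    COUNT(DISTINCT cat2) FILTER (WHERE id > 2), SUM(value) FILTER (WHERE id > 3)
    FROM data GROUP BY key"""
    out = {}
    for key in {r[0] for r in rows}:
        grp = [r for r in rows if r[0] == key]
        c1 = {r[1] for r in grp if r[4] > 1 and r[1] is not None}
        c2 = {r[2] for r in grp if r[4] > 2 and r[2] is not None}
        vals = [r[3] for r in grp if r[4] > 3]
        out[key] = (len(c1), len(c2), sum(vals) if vals else None)
    return out


def sql_max(xs):
    """MAX that ignores NULLs and returns NULL for an all-NULL input."""
    xs = [x for x in xs if x is not None]
    return max(xs) if xs else None


def rewritten(rows):
    # Expand: one copy per gid. gid 0 feeds the regular aggregate; gid 1 and 2
    # each carry one distinct column plus its FILTER condition, evaluated per row.
    expanded = []
    for key, cat1, cat2, value, id_ in rows:
        expanded.append((key, None, None, 0, None, None, value, id_))
        expanded.append((key, cat1, None, 1, id_ > 1, None, None, None))
        expanded.append((key, None, cat2, 2, None, id_ > 2, None, None))

    # First Aggregate: GROUP BY key, cat1, cat2, gid.
    groups = defaultdict(list)
    for row in expanded:
        groups[row[:4]].append(row)
    first = []
    for (key, cat1, cat2, gid), grp in groups.items():
        # max(cond) is True if any input row of this deduplicated group passed.
        max_cond1 = sql_max(r[4] for r in grp)
        max_cond2 = sql_max(r[5] for r in grp)
        # sum('value) FILTER (WHERE 'id > 3); a NULL id fails the filter.
        vals = [r[6] for r in grp if r[7] is not None and r[7] > 3]
        total = sum(vals) if vals else None
        first.append((key, cat1, cat2, gid, max_cond1, max_cond2, total))

    # Second Aggregate: GROUP BY key.
    out = {}
    for key in {r[0] for r in first}:
        grp = [r for r in first if r[0] == key]
        cat1_cnt = sum(1 for r in grp if r[3] == 1 and r[4] and r[1] is not None)
        cat2_cnt = sum(1 for r in grp if r[3] == 2 and r[5] and r[2] is not None)
        # first('total) ignore nulls FILTER (WHERE 'gid = 0)
        totals = [r[6] for r in grp if r[3] == 0 and r[6] is not None]
        out[key] = (cat1_cnt, cat2_cnt, totals[0] if totals else None)
    return out
```

The `gid` column keeps the copies apart. For example, a row with a NULL `cat1` lands in group `(key, None, None, 1)`, which stays distinct from the regular-aggregate group `(key, None, None, 0)`. The `max('cond)` columns are needed because `id` is not a grouping column: each filter has to be evaluated before deduplication and then carried through the first aggregate.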
   
   
   


-- 
This is an automated message from the Apache Git Service.