Hello, 
  The Spark SQL rule RewriteDistinctAggregates rewrites an aggregate with multiple distinct
expressions into two Aggregate nodes and an Expand node.
The following is the example from the class documentation. I wonder whether we could
reorder the second (inner) Aggregate node and the Expand node so that the Expand
generates fewer records?
Thanks


Second example: aggregate function without distinct and with filter clauses (in SQL):
   SELECT
     COUNT(DISTINCT cat1) AS cat1_cnt,
     COUNT(DISTINCT cat2) AS cat2_cnt,
     SUM(value) FILTER (WHERE id > 1) AS total
  FROM
    data
  GROUP BY
    key
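For reference, here is a minimal Python sketch of what the query computes per key. The function name reference_query and the sample rows in DATA are hypothetical, chosen only for illustration. NULL is modelled as None, and a SUM over an empty filtered set yields None, just as SQL yields NULL.

```python
from collections import defaultdict

# Hypothetical sample rows of `data`: (key, cat1, cat2, value, id).
DATA = [
    ("a", "x", "p", 10, 1),
    ("a", "x", "q", 20, 2),
    ("a", "y", "p", 30, 3),
    ("b", "z", "r", 5, 2),
    ("b", "z", "r", 7, 1),
]


def reference_query(rows):
    """Evaluate the query directly: two distinct counts and a filtered sum per key."""
    cat1s = defaultdict(set)
    cat2s = defaultdict(set)
    totals = {}
    for key, cat1, cat2, value, id_ in rows:
        if cat1 is not None:  # COUNT(DISTINCT ...) ignores NULLs
            cat1s[key].add(cat1)
        if cat2 is not None:
            cat2s[key].add(cat2)
        totals.setdefault(key, None)  # SUM over no rows is NULL
        if id_ is not None and id_ > 1 and value is not None:
            totals[key] = value if totals[key] is None else totals[key] + value
    return {k: (len(cat1s[k]), len(cat2s[k]), totals[k]) for k in totals}


print(reference_query(DATA))  # {'a': (2, 2, 50), 'b': (1, 1, 5)}
```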

This translates to the following (pseudo) logical plan:

 Aggregate(
    key = ['key]
    functions = [COUNT(DISTINCT 'cat1),
                 COUNT(DISTINCT 'cat2),
                 sum('value) with FILTER('id > 1)]
    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
   LocalTableScan [...]

This rule rewrites this logical plan to the following (pseudo) logical plan:

 Aggregate(
    key = ['key]
    functions = [count(if (('gid = 1)) 'cat1 else null),
                 count(if (('gid = 2)) 'cat2 else null),
                 first(if (('gid = 0)) 'total else null) ignore nulls]
    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
   Aggregate(
      key = ['key, 'cat1, 'cat2, 'gid]
      functions = [sum('value) with FILTER('id > 1)]
      output = ['key, 'cat1, 'cat2, 'gid, 'total])
     Expand(
        projections = [('key, null, null, 0, cast('value as bigint), 'id),
                       ('key, 'cat1, null, 1, null, null),
                       ('key, null, 'cat2, 2, null, null)]
        output = ['key, 'cat1, 'cat2, 'gid, 'value, 'id])
       LocalTableScan [...]
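The rewritten plan can be simulated in plain Python to see how the gid column keeps the three aggregates apart. This is a sketch under stated assumptions: the helper names expand, inner_aggregate and outer_aggregate and the sample rows in DATA are hypothetical, NULL is modelled as None, and the cast to bigint is omitted because Python integers are unbounded. On these rows the plan yields (2, 2, 50) for key "a" and (1, 1, 5) for key "b", the same values the SQL query computes directly.

```python
# Hypothetical sample rows of `data`: (key, cat1, cat2, value, id).
DATA = [
    ("a", "x", "p", 10, 1),
    ("a", "x", "q", 20, 2),
    ("a", "y", "p", 30, 3),
    ("b", "z", "r", 5, 2),
    ("b", "z", "r", 7, 1),
]


def expand(rows):
    """Expand: three projections per input row -> (key, cat1, cat2, gid, value, id)."""
    out = []
    for key, cat1, cat2, value, id_ in rows:
        out.append((key, None, None, 0, value, id_))  # gid 0: input of sum(...) FILTER
        out.append((key, cat1, None, 1, None, None))  # gid 1: input of COUNT(DISTINCT cat1)
        out.append((key, None, cat2, 2, None, None))  # gid 2: input of COUNT(DISTINCT cat2)
    return out


def inner_aggregate(rows):
    """Group by (key, cat1, cat2, gid); compute sum(value) FILTER (WHERE id > 1).

    Grouping collapses duplicate (key, cat1) rows for gid 1 and duplicate
    (key, cat2) rows for gid 2, which is what makes the outer count distinct.
    """
    groups = {}
    for key, cat1, cat2, gid, value, id_ in rows:
        group = (key, cat1, cat2, gid)
        groups.setdefault(group, None)
        # NULL > 1 is not true in SQL, so rows with a NULL id are filtered out.
        if id_ is not None and id_ > 1 and value is not None:
            prev = groups[group]
            groups[group] = value if prev is None else prev + value
    return [group + (total,) for group, total in groups.items()]


def outer_aggregate(rows):
    """Group by key: count gid=1 and gid=2 rows, take the first non-NULL gid=0 total."""
    result = {}
    for key, cat1, cat2, gid, total in rows:
        cnt1, cnt2, first_total = result.get(key, (0, 0, None))
        if gid == 1 and cat1 is not None:
            cnt1 += 1
        if gid == 2 and cat2 is not None:
            cnt2 += 1
        if gid == 0 and first_total is None:
            first_total = total  # first(...) ignore nulls
        result[key] = (cnt1, cnt2, first_total)
    return result


expanded = expand(DATA)
print(len(expanded))  # 15: three rows per input row
print(outer_aggregate(inner_aggregate(expanded)))  # {'a': (2, 2, 50), 'b': (1, 1, 5)}
```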

Could we rewrite this logical plan to the following?

 Aggregate(
    key = ['key]
    functions = [count(if (('gid = 1)) 'cat1 else null),
                 count(if (('gid = 2)) 'cat2 else null),
                 first(if (('gid = 0)) 'total else null) ignore nulls]
    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
   Expand(
     projections = [('key, 'total, null, null, 0, cast('value as bigint)),
                    ('key, 'total, 'cat1, null, 1, null),
                    ('key, 'total, null, 'cat2, 2, null)]
     output = ['key, 'total, 'cat1, 'cat2, 'gid, 'value])
      Aggregate(
         key = ['key, 'cat1, 'cat2]
         functions = [sum('value) with FILTER('id > 1)]
         output = ['key, 'cat1, 'cat2, 'total])
       LocalTableScan [...]
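To put a number on "fewer records", the sketch below counts the rows each Expand would emit for the same kind of hypothetical sample rows. It assumes, as in both plans above, three projections (one for the non-distinct aggregate and one per distinct column); the names DATA, NUM_PROJECTIONS, expand_rows_current and expand_rows_proposed are hypothetical. It only counts rows and does not check whether the outer Aggregate still returns the same per-key results under the proposed ordering.

```python
# Hypothetical sample rows of `data`: (key, cat1, cat2, value, id).
DATA = [
    ("a", "x", "p", 10, 1),
    ("a", "x", "q", 20, 2),
    ("a", "y", "p", 30, 3),
    ("b", "z", "r", 5, 2),
    ("b", "z", "r", 7, 1),
]

# One projection for the non-distinct aggregate plus one per distinct column.
NUM_PROJECTIONS = 3


def expand_rows_current(rows):
    """Current plan: Expand reads the scan directly, emitting 3 rows per input row."""
    return NUM_PROJECTIONS * len(rows)


def expand_rows_proposed(rows):
    """Proposed plan: Expand reads Aggregate(key = [key, cat1, cat2]), one row per group."""
    groups = {(key, cat1, cat2) for key, cat1, cat2, _value, _id in rows}
    return NUM_PROJECTIONS * len(groups)


print(expand_rows_current(DATA), expand_rows_proposed(DATA))  # 15 12
```

The saving depends on how many input rows share the same (key, cat1, cat2) combination; when those combinations are nearly unique, the two counts are close.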
