karuppayya commented on a change in pull request #28804:
URL: https://github.com/apache/spark/pull/28804#discussion_r466748030
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
##########
@@ -409,6 +411,12 @@ case class HashAggregateExec(
private var fastHashMapTerm: String = _
private var isFastHashMapEnabled: Boolean = false
+ private var avoidSpillInPartialAggregateTerm: String = _
+ private val skipPartialAggregate = sqlContext.conf.skipPartialAggregate &&
+ AggUtils.areAggExpressionsPartial(modes) &&
find(_.isInstanceOf[ExpandExec]).isEmpty
Review comment:
This check is required to avoid the optimization for queries with more than one
distinct aggregate.
`org.apache.spark.sql.catalyst.optimizer.RewriteDistinctAggregates` takes
care of rewriting aggregates with more than one distinct, and the rewritten
plan assumes that the map-side (partial) aggregation has already performed
the distinct operation.
With my change, skipping that partial aggregation would produce wrong results.
For example:
For the first example given as part of comments in the rule
```
* First example: query without filter clauses (in scala):
* {{{
* val data = Seq(
* ("a", "ca1", "cb1", 10),
* ("a", "ca1", "cb2", 5),
* ("b", "ca1", "cb1", 13))
* .toDF("key", "cat1", "cat2", "value")
* data.createOrReplaceTempView("data")
*
* val agg = data.groupBy($"key")
* .agg(
* countDistinct($"cat1").as("cat1_cnt"),
* countDistinct($"cat2").as("cat2_cnt"),
* sum($"value").as("total"))
* }}}
*
* This translates to the following (pseudo) logical plan:
* {{{
* Aggregate(
* key = ['key]
* functions = [COUNT(DISTINCT 'cat1),
* COUNT(DISTINCT 'cat2),
* sum('value)]
* output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
* LocalTableScan [...]
* }}}
*
* This rule rewrites this logical plan to the following (pseudo) logical
plan:
* {{{
* Aggregate(
* key = ['key]
* functions = [count(if (('gid = 1)) 'cat1 else null),
* count(if (('gid = 2)) 'cat2 else null),
* first(if (('gid = 0)) 'total else null) ignore nulls]
* output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
* Aggregate(
* key = ['key, 'cat1, 'cat2, 'gid]
* functions = [sum('value)]
* output = ['key, 'cat1, 'cat2, 'gid, 'total])
* Expand(
* projections = [('key, null, null, 0, cast('value as bigint)),
* ('key, 'cat1, null, 1, null),
* ('key, null, 'cat2, 2, null)]
* output = ['key, 'cat1, 'cat2, 'gid, 'value])
* LocalTableScan [...]
* }}}
*
```
Say the dataset contains the following two records:
```
rec 1: ("key1", "cat1", "cat1", 1)
rec 2: ("key1", "cat1", "cat1", 1)
```
With my change, after `Expand` the rows are:
```
("key1", null, null, 0, 1)
("key1", "cat1", null, 1, null)
("key1", null, "cat1", 2, null)
("key1", null, null, 0, 1)
("key1", "cat1", null, 1, null)
("key1", null, "cat1", 2, null)
```
Since partial aggregation is skipped, these rows reach the reducer
un-deduplicated:
```
("key1", null, null, 0, 1)
("key1", "cat1", null, 1, null)
("key1", null, "cat1", 2, null)
("key1", null, null, 0, 1)
("key1", "cat1", null, 1, null)
("key1", null, "cat1", 2, null)
```
Reducer-side aggregation result: **(key1, 2, 2, 2)**
But the correct answer is: **(key1, 1, 1, 2)**
Hence we check for the presence of an `Expand` node and do not skip partial
aggregation when one is present.
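To make the failure mode concrete, here is a minimal standalone sketch (plain Scala collections, no Spark; the object and helper names are hypothetical) that replays the `Expand` projections over the two duplicate records and evaluates the final `count(if (('gid = 1)) 'cat1 else null)` with and without the map-side de-duplication:

```scala
// Simulates the counterexample: the top-level count over gid = 1 rows
// counts *rows*, not distinct values, because RewriteDistinctAggregates
// assumes the partial aggregate already de-duplicated (key, cat1, cat2, gid).
object SkipPartialAggCounterexample {
  // Input records: (key, cat1, cat2, value)
  val records = Seq(("key1", "cat1", "cat1", 1), ("key1", "cat1", "cat1", 1))

  // Expand: one output row per projection, shaped (key, cat1, cat2, gid, value)
  def expand(r: (String, String, String, Int))
      : Seq[(String, String, String, Int, Option[Int])] = Seq(
    (r._1, null, null, 0, Option(r._4)), // gid = 0 carries the value
    (r._1, r._2, null, 1, None),         // gid = 1 carries cat1
    (r._1, null, r._3, 2, None))         // gid = 2 carries cat2

  // Partial (map-side) aggregate: group by (key, cat1, cat2, gid),
  // which de-duplicates rows and sums the carried values.
  def partialAgg(rows: Seq[(String, String, String, Int, Option[Int])])
      : Seq[(String, String, String, Int, Option[Int])] =
    rows.groupBy(r => (r._1, r._2, r._3, r._4)).map { case (k, g) =>
      (k._1, k._2, k._3, k._4, g.flatMap(_._5).reduceOption(_ + _))
    }.toSeq

  // Final aggregate's cat1_cnt: count(if (gid = 1)) cat1 else null —
  // a plain count of non-null cat1 rows, not a distinct count.
  def cat1Cnt(rows: Seq[(String, String, String, Int, Option[Int])]): Int =
    rows.count(r => r._4 == 1 && r._2 != null)

  def main(args: Array[String]): Unit = {
    val expanded = records.flatMap(expand)
    println(cat1Cnt(partialAgg(expanded))) // partial agg ran: 1 (correct)
    println(cat1Cnt(expanded))             // partial agg skipped: 2 (wrong)
  }
}
```

With the partial aggregate in place the duplicate `("key1", "cat1", null, 1, null)` rows collapse into one and the count is correct; skipping it leaves both rows, and the reducer double-counts exactly as in the example above.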