c21 commented on a change in pull request #32242:
URL: https://github.com/apache/spark/pull/32242#discussion_r616468793
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
##########
@@ -128,6 +128,16 @@ case class HashAggregateExec(
// all the mode of aggregate expressions
private val modes = aggregateExpressions.map(_.mode).distinct
+ // This is for testing final aggregate with number-of-rows-based fall back as specified in
+ // `testFallbackStartsAt`. In this scenario, there might be same keys exist in both fast and
+ // regular hash map. So the aggregation buffers from both maps need to be merged together
+ // to avoid correctness issue.
+ //
+ // This scenario only happens in unit test with number-of-rows-based fall back.
+ // There should not be same keys in both maps with size-based fall back in production.
+ private val isTestFinalAggregateWithFallback: Boolean = testFallbackStartsAt.isDefined &&
Review comment:
@cloud-fan - sure. This is how the number-of-rows-based fallback works.
With the internal config `spark.sql.TungstenAggregate.testFallbackStartsAt`, we can set (1) when to fall back from the first-level hash map to the second-level hash map, and (2) when to fall back from the second-level hash map to sort-based aggregation.
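For context, the fallback can be forced roughly like this (a minimal sketch of the test wiring; the session setup and query here are just for illustration, not code from the PR):
```
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").appName("fallback-demo").getOrCreate()

// Fall back from the fast (1st level) hash map after 2 input rows, and from the
// regular (2nd level) hash map to sort-based aggregation after 3 input rows.
spark.conf.set("spark.sql.TungstenAggregate.testFallbackStartsAt", "2, 3")

// Any hash aggregate now exercises the number-of-rows-based fallback path.
spark.range(0, 100).selectExpr("id % 10 AS k").groupBy("k").count().show()
```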
Suppose `spark.sql.TungstenAggregate.testFallbackStartsAt` = "2, 3".
Then the generated code per input row (which aggregates the row into the hash map) looks like:
```
UnsafeRow agg_buffer = null;
if (counter < 2) {
  // 1st level hash map
  agg_buffer = fastHashMap.findOrInsert(key);
}
if (agg_buffer == null) {
  // generated code for key in unsafe row format
  ...
  if (counter < 3) {
    // 2nd level hash map
    agg_buffer = regularHashMap.getAggregationBufferFromUnsafeRow(key_in_unsafe_row, ...);
  }
  if (agg_buffer == null) {
    // sort-based fallback
    regularHashMap.destructAndCreateExternalSorter();
    ...
    counter = 0;
  }
}
counter += 1;
```
Example generated code is at lines 187-232 of this gist: https://gist.github.com/c21/d0f704c0a33c24ec05387ff4df438bff
I tried to add a method `fastHashMap.find(key): boolean` and change the code like this:
```
...
if (fastHashMap.find(key) || counter < 2) {
  // 1st level hash map
  agg_buffer = fastHashMap.findOrInsert(key);
}
...
```
But I later found the case I mentioned above (a rough sketch of the sequence follows this list):
1. key(a) is inserted into the second-level hash map (when the counter exceeds the 1st threshold).
2. The sort-based fallback happens and the counter is reset to 0 (when the counter exceeds the 2nd threshold).
3. key(a) is not in the first-level hash map and the counter no longer exceeds the 1st threshold, so key(a) is inserted into the first-level hash map as well by mistake.
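To make that sequence concrete, here is a rough, self-contained Scala simulation using the same "2, 3" thresholds and the `find(key)` variant above (toy mutable maps and a list stand in for the fast hash map, the regular hash map and the external sorter; the names are made up, not the real generated code):
```
import scala.collection.mutable

object FallbackScenario {
  val fastMap = mutable.Map.empty[String, Long]           // 1st level hash map
  val regularMap = mutable.Map.empty[String, Long]         // 2nd level hash map
  val spilled = mutable.ListBuffer.empty[(String, Long)]   // external sorter
  var counter = 0

  def aggregate(key: String): Unit = {
    var handled = false
    if (fastMap.contains(key) || counter < 2) {
      // 1st level hash map
      fastMap(key) = fastMap.getOrElse(key, 0L) + 1
      handled = true
    }
    if (!handled) {
      if (counter < 3) {
        // 2nd level hash map
        regularMap(key) = regularMap.getOrElse(key, 0L) + 1
      } else {
        // sort-based fallback: spill the regular map and reset the counter
        spilled ++= regularMap
        regularMap.clear()
        counter = 0
        regularMap(key) = regularMap.getOrElse(key, 0L) + 1
      }
    }
    counter += 1
  }

  def main(args: Array[String]): Unit = {
    // "a" first goes to the regular map (step 1), "z" triggers the sort-based
    // fallback and the counter reset (step 2), then "a" comes again and is
    // inserted into the fast map by mistake (step 3).
    Seq("x", "y", "a", "z", "a").foreach(aggregate)
    println(s"fast map: $fastMap")  // now holds a buffer for "a"
    println(s"spilled:  $spilled")  // the earlier buffer for "a" is here
  }
}
```
So even with the extra `find(key)` check, key(a) ends up with partial aggregation buffers in two places after the counter reset, and those buffers would still need to be merged.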
We can further add code like this:
```
if ((fastHashMap.find(key) && !regularHashMap.find(key_in_unsafe_row)) || counter < 2) {
  // 1st level hash map
  agg_buffer = fastHashMap.findOrInsert(key);
}
```
But it introduces a more ad-hoc change and looks pretty ugly, with a lot of code needing to be moved around.