xiedeyantu commented on code in PR #4495:
URL: https://github.com/apache/calcite/pull/4495#discussion_r2291302059
##########
core/src/main/java/org/apache/calcite/tools/RelBuilder.java:
##########
@@ -2817,31 +2817,52 @@ private RelBuilder aggregate_(ImmutableBitSet groupSet,
*
* <p>Also see the discussion in
* <a
href="https://issues.apache.org/jira/browse/CALCITE-1824">[CALCITE-1824]
- * GROUP_ID returns wrong result</a> and
+ * GROUP_ID returns wrong result</a>,
* <a
href="https://issues.apache.org/jira/browse/CALCITE-4748">[CALCITE-4748]
* If there are duplicate GROUPING SETS, Calcite should return duplicate
- * rows</a>.
+ * rows</a> and
+ * <a
href="https://issues.apache.org/jira/browse/CALCITE-7126">[CALCITE-7126]
+ * The calculation result of grouping function is wrong</a>.
*/
private RelBuilder rewriteAggregateWithDuplicateGroupSets(
ImmutableBitSet groupSet,
ImmutableSortedMultiset<ImmutableBitSet> groupSets,
List<AggCallPlus> aggregateCalls) {
+ List<AggregateCall> calls =
+ aggregateCalls.stream()
+ .map(AggCallPlus::aggregateCall)
+ .collect(Collectors.toList());
+ return rewriteAggregateWithDuplicateGroupSetsByAggregateCall(groupSet,
groupSets, calls);
+ }
+
+ private RelBuilder rewriteAggregateWithDuplicateGroupSetsByAggregateCall(
+ ImmutableBitSet groupSet,
+ ImmutableSortedMultiset<ImmutableBitSet> groupSets,
+ List<AggregateCall> aggregateCalls) {
+ final int groupCount = groupSet.cardinality();
final List<String> fieldNamesIfNoRewrite =
- Aggregate.deriveRowType(getTypeFactory(), peek().getRowType(), false,
- groupSet, groupSets.asList(),
- aggregateCalls.stream().map(AggCallPlus::aggregateCall)
- .collect(toImmutableList())).getFieldNames();
-
- // If n duplicates exist for a particular grouping, the {@code GROUP_ID()}
- // function produces values in the range 0 to n-1. For each value,
- // we need to figure out the corresponding group sets.
- //
- // For example, "... GROUPING SETS (a, a, b, c, c, c, c)"
- // (i) The max value of the GROUP_ID() function returns is 3
- // (ii) GROUPING SETS (a, b, c) produces value 0,
- // GROUPING SETS (a, c) produces value 1,
- // GROUPING SETS (c) produces value 2
- // GROUPING SETS (c) produces value 3
+ Aggregate.deriveRowType(getTypeFactory(), peek().getRowType(),
+ false, groupSet, groupSets.asList(),
aggregateCalls).getFieldNames();
+
+ Mappings.TargetMapping mapping =
+ Mappings.target(groupSet.toList(),
peek().getRowType().getFieldCount());
+
+ final Frame input = stack.pop();
+
+ // Only expand fully when GROUPING functions exist
+ boolean hasGroupingFunction = aggregateCalls.stream()
+ .anyMatch(call -> call.getAggregation().getKind() == SqlKind.GROUPING);
+
+ if (hasGroupingFunction) {
+ Map<ImmutableBitSet, Integer> groupSetToGroupId = new HashMap<>();
+ for (ImmutableBitSet gs : groupSets) {
+ int groupId = groupSetToGroupId.compute(gs, (k, v) -> v == null ? 0 :
v + 1);
+ rewriteGroupAggCalls(ImmutableSet.of(gs), groupSet, aggregateCalls,
groupCount,
+ fieldNamesIfNoRewrite, mapping, input, groupId);
+ }
+ return union(true, groupSets.size());
+ }
+
final Map<Integer, Set<ImmutableBitSet>> groupIdToGroupSets = new
HashMap<>();
int maxGroupId = 0;
Review Comment:
This has been removed.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]