xiedeyantu commented on code in PR #4636:
URL: https://github.com/apache/calcite/pull/4636#discussion_r2543990106
##########
core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java:
##########
@@ -520,43 +609,207 @@ private static void
rewriteUsingGroupingSets(RelOptRuleCall call,
relBuilder.project(nodes);
}
- int x = groupCount;
- final ImmutableBitSet groupSet = aggregate.getGroupSet();
- final List<AggregateCall> newCalls = new ArrayList<>();
- for (AggregateCall aggCall : aggregate.getAggCallList()) {
- final int newFilterArg;
- final List<Integer> newArgList;
- final SqlAggFunction aggregation;
+ // Compute the remapped top-group key and grouping sets. The top-group key
+ // selects which fields of the bottom result correspond to the original
+ // aggregate's group-by columns. Upper aggregates (line 3) will group by
+ // this key.
+ final ImmutableBitSet topGroupKey = remap(fullGroupSet, aggregateGroupSet);
+ final ImmutableList<ImmutableBitSet> topGroupingSets =
+ remap(fullGroupSet, aggregate.getGroupSets());
+ final int topGroupCount = topGroupKey.cardinality();
+ final boolean needsGroupingIndicators = aggregate.getGroupType() !=
Group.SIMPLE;
+ final List<Integer> groupingIndicatorOrdinals;
+ if (needsGroupingIndicators) {
+ groupingIndicatorOrdinals =
+ new ArrayList<>(Collections.nCopies(aggregateGroupingSets.size(),
-1));
+ } else {
+ groupingIndicatorOrdinals = ImmutableList.of();
+ }
+
+ int valueIndex = bottomGroupCount;
+ // line 3 will be built from this list
+ final List<AggregateCall> upperAggCalls = new ArrayList<>();
+ final List<List<Integer>> aggCallOrdinals = new ArrayList<>();
+ final List<AggregateCall> aggCalls = aggregate.getAggCallList();
+
+ // The first part of line 3: Build upper aggregates per declared grouping
+ // set. For each original aggCall we create one upper agg per declared
+ // grouping set. The upper aggregate groups by {@code topGroupKey} and uses
+ // the boolean marker columns (placed at known ordinals) as the FILTER
+ // argument for the corresponding per-group aggregation. The list
+ // {@code aggCallOrdinals} records, for each original aggCall, the output
+ // field ordinals of the corresponding upper-aggregate results (one per
+ // grouping set).
+ for (AggregateCall aggCall : aggCalls) {
+ final List<Integer> ordinals = new ArrayList<>();
if (!aggCall.isDistinct()) {
- aggregation = SqlStdOperatorTable.MIN;
- newArgList = ImmutableIntList.of(x++);
- newFilterArg =
- requireNonNull(filters.get(Pair.of(groupSet, -1)),
- "filters.get(Pair.of(groupSet, -1))");
+ final int inputIndex = valueIndex++;
+ final List<Integer> args = ImmutableIntList.of(inputIndex);
+ for (int g = 0; g < aggregateGroupingSets.size(); g++) {
+ final ImmutableBitSet groupingSet = aggregateGroupingSets.get(g);
+ final int newFilterArg =
+ requireNonNull(filters.get(Pair.of(groupingSet, -1)),
+ () -> "filters.get(" + groupingSet + ", -1)");
+ final String upperAggName = upperAggCallName(aggCall, g);
+ // Each filtered grouping set emits at most one row, so MIN just
Review Comment:
My apologies, I should have reminded you of the comment's location sooner.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]