mihaibudiu commented on code in PR #4636:
URL: https://github.com/apache/calcite/pull/4636#discussion_r2542936362
##########
core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java:
##########
@@ -520,43 +609,203 @@ private static void
rewriteUsingGroupingSets(RelOptRuleCall call,
relBuilder.project(nodes);
}
- int x = groupCount;
- final ImmutableBitSet groupSet = aggregate.getGroupSet();
- final List<AggregateCall> newCalls = new ArrayList<>();
- for (AggregateCall aggCall : aggregate.getAggCallList()) {
- final int newFilterArg;
- final List<Integer> newArgList;
- final SqlAggFunction aggregation;
+ // Compute the remapped top-group key and grouping sets. The top-group key
+ // selects which fields of the bottom result correspond to the original
+ // aggregate's group-by columns. Upper aggregates will group by this key.
+ final ImmutableBitSet topGroupKey = remap(fullGroupSet, aggregateGroupSet);
+ final ImmutableList<ImmutableBitSet> topGroupingSets =
+ remap(fullGroupSet, aggregate.getGroupSets());
+ final int topGroupCount = topGroupKey.cardinality();
+ final boolean needsGroupingIndicators = aggregate.getGroupType() !=
Group.SIMPLE;
+ final List<Integer> groupingIndicatorOrdinals;
+ if (needsGroupingIndicators) {
+ groupingIndicatorOrdinals =
+ new ArrayList<>(Collections.nCopies(aggregateGroupingSets.size(),
-1));
+ } else {
+ groupingIndicatorOrdinals = ImmutableList.of();
+ }
+
+ int valueIndex = bottomGroupCount;
+ final List<AggregateCall> upperAggCalls = new ArrayList<>();
+ final List<List<Integer>> aggCallOrdinals = new ArrayList<>();
+ final List<AggregateCall> aggCalls = aggregate.getAggCallList();
+
+ // Build upper aggregates per declared grouping set. For each original
+ // aggCall we create one upper agg per declared grouping set. The upper
+ // aggregate groups by {@code topGroupKey} and uses the boolean marker
+ // columns (placed at known ordinals) as the FILTER argument for the
+ // corresponding per-group aggregation. The list {@code aggCallOrdinals}
+ // records, for each original aggCall, the output field ordinals of the
+ // corresponding upper-aggregate results (one per grouping set).
+ for (AggregateCall aggCall : aggCalls) {
+ final List<Integer> ordinals = new ArrayList<>();
if (!aggCall.isDistinct()) {
- aggregation = SqlStdOperatorTable.MIN;
- newArgList = ImmutableIntList.of(x++);
- newFilterArg =
- requireNonNull(filters.get(Pair.of(groupSet, -1)),
- "filters.get(Pair.of(groupSet, -1))");
+ final int inputIndex = valueIndex++;
+ final List<Integer> args = ImmutableIntList.of(inputIndex);
+ for (int g = 0; g < aggregateGroupingSets.size(); g++) {
+ final ImmutableBitSet groupingSet = aggregateGroupingSets.get(g);
+ final int newFilterArg =
+ requireNonNull(filters.get(Pair.of(groupingSet, -1)),
+ () -> "filters.get(" + groupingSet + ", -1)");
+ final String upperAggName = upperAggCallName(aggCall, g);
+ final AggregateCall newCall =
+ AggregateCall.create(aggCall.getParserPosition(),
+ SqlStdOperatorTable.MIN, false, aggCall.isApproximate(),
+ aggCall.ignoreNulls(), aggCall.rexList, args, newFilterArg,
+ aggCall.distinctKeys, aggCall.collation,
aggregate.hasEmptyGroup(),
+ relBuilder.peek(), null, upperAggName);
+ upperAggCalls.add(newCall);
+ ordinals.add(topGroupCount + upperAggCalls.size() - 1);
+ }
} else {
- aggregation = aggCall.getAggregation();
- newArgList = remap(fullGroupSet, aggCall.getArgList());
- final ImmutableBitSet newGroupSet =
ImmutableBitSet.of(aggCall.getArgList())
- .setIf(aggCall.filterArg, aggCall.filterArg >= 0)
- .union(groupSet);
- newFilterArg =
- requireNonNull(filters.get(Pair.of(newGroupSet,
aggCall.filterArg)),
- "filters.get(of(newGroupSet, aggCall.filterArg))");
+ final List<Integer> newArgList = remap(fullGroupSet,
aggCall.getArgList());
+ for (int g = 0; g < aggregateGroupingSets.size(); g++) {
+ final ImmutableBitSet groupingSet = aggregateGroupingSets.get(g);
+ final ImmutableBitSet newGroupSet =
ImmutableBitSet.of(aggCall.getArgList())
+ .setIf(aggCall.filterArg, aggCall.filterArg >= 0)
+ .union(groupingSet);
+ final int newFilterArg =
+ requireNonNull(filters.get(Pair.of(newGroupSet,
aggCall.filterArg)),
+ () -> "filters.get(" + newGroupSet + ", " +
aggCall.filterArg + ")");
+ final String upperAggName = upperAggCallName(aggCall, g);
+ final AggregateCall newCall =
+ AggregateCall.create(aggCall.getParserPosition(),
aggCall.getAggregation(), false,
+ aggCall.isApproximate(), aggCall.ignoreNulls(),
+ aggCall.rexList, newArgList, newFilterArg,
+ aggCall.distinctKeys, aggCall.collation,
+ aggregate.hasEmptyGroup(), relBuilder.peek(), null,
upperAggName);
+ upperAggCalls.add(newCall);
+ ordinals.add(topGroupCount + upperAggCalls.size() - 1);
+ }
}
- final AggregateCall newCall =
- AggregateCall.create(aggCall.getParserPosition(), aggregation, false,
- aggCall.isApproximate(), aggCall.ignoreNulls(),
- aggCall.rexList, newArgList, newFilterArg,
- aggCall.distinctKeys, aggCall.collation,
- aggregate.hasEmptyGroup(), relBuilder.peek(), null,
aggCall.name);
- newCalls.add(newCall);
+ aggCallOrdinals.add(ordinals);
}
+ // If grouping indicators are needed (ROLLUP/CUBE/GROUPING SETS with more
+ // than one grouping set), add COUNT(...) presence calls which are later
+ // used to determine whether a grouping set produced any rows. These are
+ // used to implement semantics where empty grouping sets still must
+ // produce a result.
Review Comment:
This is again part of line 3 of the plan; it would help to say so in the comment.
##########
core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java:
##########
@@ -520,43 +609,203 @@ private static void
rewriteUsingGroupingSets(RelOptRuleCall call,
relBuilder.project(nodes);
}
- int x = groupCount;
- final ImmutableBitSet groupSet = aggregate.getGroupSet();
- final List<AggregateCall> newCalls = new ArrayList<>();
- for (AggregateCall aggCall : aggregate.getAggCallList()) {
- final int newFilterArg;
- final List<Integer> newArgList;
- final SqlAggFunction aggregation;
+ // Compute the remapped top-group key and grouping sets. The top-group key
+ // selects which fields of the bottom result correspond to the original
+ // aggregate's group-by columns. Upper aggregates will group by this key.
Review Comment:
"Upper aggregates" here means the aggregate on line 3 of the plan; consider saying so explicitly.
##########
core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java:
##########
@@ -520,43 +609,203 @@ private static void
rewriteUsingGroupingSets(RelOptRuleCall call,
relBuilder.project(nodes);
}
- int x = groupCount;
- final ImmutableBitSet groupSet = aggregate.getGroupSet();
- final List<AggregateCall> newCalls = new ArrayList<>();
- for (AggregateCall aggCall : aggregate.getAggCallList()) {
- final int newFilterArg;
- final List<Integer> newArgList;
- final SqlAggFunction aggregation;
+ // Compute the remapped top-group key and grouping sets. The top-group key
+ // selects which fields of the bottom result correspond to the original
+ // aggregate's group-by columns. Upper aggregates will group by this key.
+ final ImmutableBitSet topGroupKey = remap(fullGroupSet, aggregateGroupSet);
+ final ImmutableList<ImmutableBitSet> topGroupingSets =
+ remap(fullGroupSet, aggregate.getGroupSets());
+ final int topGroupCount = topGroupKey.cardinality();
+ final boolean needsGroupingIndicators = aggregate.getGroupType() !=
Group.SIMPLE;
+ final List<Integer> groupingIndicatorOrdinals;
+ if (needsGroupingIndicators) {
+ groupingIndicatorOrdinals =
+ new ArrayList<>(Collections.nCopies(aggregateGroupingSets.size(),
-1));
+ } else {
+ groupingIndicatorOrdinals = ImmutableList.of();
+ }
+
+ int valueIndex = bottomGroupCount;
+ final List<AggregateCall> upperAggCalls = new ArrayList<>();
Review Comment:
Line 3 of the plan will be built from this list; consider mentioning that in a comment.
##########
core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java:
##########
@@ -425,39 +426,118 @@ private static RelBuilder
convertSingletonDistinct(RelBuilder relBuilder,
return relBuilder;
}
+ /**
+ * Rewrite aggregates that use GROUPING SETS. The following SQL/plan example
+ * serves as the concrete blueprint, starting from the original statement and
+ * plan-before outputs and then rebuilding the plan-after tree from the
bottom
+ * (line 7) back to the top (line 1):
+ *
+ * <p>Original SQL:
+ * <pre>{@code
+ * SELECT deptno, COUNT(DISTINCT sal)
+ * FROM emp
+ * GROUP BY ROLLUP(deptno)
+ * }</pre>
+ *
+ * <p>Plan before rewrite:
+ * <pre>{@code
+ * LogicalAggregate(group=[{0}], groups=[[{0}, {}]], EXPR$1=[COUNT(DISTINCT
$1)])
+ * LogicalProject(DEPTNO=[$7], SAL=[$5])
+ * LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ * }</pre>
+ *
+ * <p>Plan after rewrite (lines referenced below):
+ * <pre>{@code
+ * 1 LogicalProject(DEPTNO=[$0],
+ * EXPR$1=[CAST(CASE(=($5, 0), $1, =($5, 1), $2,
+ * null:BIGINT)):BIGINT NOT NULL])
+ * 2 LogicalFilter(condition=[OR(
+ * AND(=($5, 0), >($3, 0)), =($5, 1))])
+ * 3 LogicalAggregate(group=[{0}], groups=[[{0}, {}]],
+ * EXPR$1_g0=[COUNT($1) FILTER $2],
+ * EXPR$1_g1=[COUNT($1) FILTER $4],
+ * $g_present_0=[COUNT() FILTER $3],
+ * $g_present_1=[COUNT() FILTER $5],
+ * $g_final=[GROUPING($0)])
+ * 4 LogicalProject(DEPTNO=[$0], SAL=[$1],
Review Comment:
Does this always contain all subsets of the columns involved (2^n of them)?
That could become large.
We could leave an optimization for this to a later PR.
Or maybe pushing down the outer project will take care of it?
##########
core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java:
##########
@@ -520,43 +609,203 @@ private static void
rewriteUsingGroupingSets(RelOptRuleCall call,
relBuilder.project(nodes);
}
- int x = groupCount;
- final ImmutableBitSet groupSet = aggregate.getGroupSet();
- final List<AggregateCall> newCalls = new ArrayList<>();
- for (AggregateCall aggCall : aggregate.getAggCallList()) {
- final int newFilterArg;
- final List<Integer> newArgList;
- final SqlAggFunction aggregation;
+ // Compute the remapped top-group key and grouping sets. The top-group key
+ // selects which fields of the bottom result correspond to the original
+ // aggregate's group-by columns. Upper aggregates will group by this key.
+ final ImmutableBitSet topGroupKey = remap(fullGroupSet, aggregateGroupSet);
+ final ImmutableList<ImmutableBitSet> topGroupingSets =
+ remap(fullGroupSet, aggregate.getGroupSets());
+ final int topGroupCount = topGroupKey.cardinality();
+ final boolean needsGroupingIndicators = aggregate.getGroupType() !=
Group.SIMPLE;
+ final List<Integer> groupingIndicatorOrdinals;
+ if (needsGroupingIndicators) {
+ groupingIndicatorOrdinals =
+ new ArrayList<>(Collections.nCopies(aggregateGroupingSets.size(),
-1));
+ } else {
+ groupingIndicatorOrdinals = ImmutableList.of();
+ }
+
+ int valueIndex = bottomGroupCount;
+ final List<AggregateCall> upperAggCalls = new ArrayList<>();
+ final List<List<Integer>> aggCallOrdinals = new ArrayList<>();
+ final List<AggregateCall> aggCalls = aggregate.getAggCallList();
+
+ // Build upper aggregates per declared grouping set. For each original
Review Comment:
This builds line 3 of the plan; consider referencing it in the comment.
##########
core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java:
##########
@@ -425,39 +426,118 @@ private static RelBuilder
convertSingletonDistinct(RelBuilder relBuilder,
return relBuilder;
}
+ /**
+ * Rewrite aggregates that use GROUPING SETS. The following SQL/plan example
+ * serves as the concrete blueprint, starting from the original statement and
+ * plan-before outputs and then rebuilding the plan-after tree from the
bottom
+ * (line 7) back to the top (line 1):
+ *
+ * <p>Original SQL:
+ * <pre>{@code
+ * SELECT deptno, COUNT(DISTINCT sal)
+ * FROM emp
+ * GROUP BY ROLLUP(deptno)
+ * }</pre>
+ *
+ * <p>Plan before rewrite:
+ * <pre>{@code
+ * LogicalAggregate(group=[{0}], groups=[[{0}, {}]], EXPR$1=[COUNT(DISTINCT
$1)])
+ * LogicalProject(DEPTNO=[$7], SAL=[$5])
+ * LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ * }</pre>
+ *
+ * <p>Plan after rewrite (lines referenced below):
+ * <pre>{@code
+ * 1 LogicalProject(DEPTNO=[$0],
+ * EXPR$1=[CAST(CASE(=($5, 0), $1, =($5, 1), $2,
+ * null:BIGINT)):BIGINT NOT NULL])
+ * 2 LogicalFilter(condition=[OR(
+ * AND(=($5, 0), >($3, 0)), =($5, 1))])
+ * 3 LogicalAggregate(group=[{0}], groups=[[{0}, {}]],
+ * EXPR$1_g0=[COUNT($1) FILTER $2],
+ * EXPR$1_g1=[COUNT($1) FILTER $4],
+ * $g_present_0=[COUNT() FILTER $3],
+ * $g_present_1=[COUNT() FILTER $5],
+ * $g_final=[GROUPING($0)])
+ * 4 LogicalProject(DEPTNO=[$0], SAL=[$1],
+ * $g_0=[=($2, 0)], $g_1=[=($2, 1)],
+ * $g_2=[=($2, 2)], $g_3=[=($2, 3)])
+ * 5 LogicalAggregate(group=[{0, 1}],
+ * groups=[[{0, 1}, {0}, {1}, {}]], $g=[GROUPING($0, $1)])
+ * 6 LogicalProject(DEPTNO=[$7], SAL=[$5])
+ * 7 LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ * }</pre>
+ *
+ * <p>The method performs the following actions:
+ * <ul>
+ * <li>Reuse the incoming scan and projection (lines 7 and 6) by pushing the
+ * original aggregate input onto the builder.</li>
+ * <li>Enumerate all grouping-set combinations and run the "bottom" aggregate
+ * over {@code fullGroupSet} to materialize line 5, including the internal
+ * {@code GROUPING()} value.</li>
+ * <li>Project the boolean selector columns that compare {@code GROUPING()}
+ * outputs to the required combinations, which surfaces line 4.</li>
+ * <li>Build the "upper" grouping-set aggregates with per-set FILTER clauses,
+ * reproducing line 3 and retaining presence counters / grouping ids.</li>
+ * <li>Assemble {@code keepConditions} so we can emit the filter of line 2
that
+ * drops internal-only rows.</li>
+ * <li>Produce the final projection (line 1) that routes each aggregate
result
+ * to the user-visible columns.</li>
+ * </ul>
+ */
private static void rewriteUsingGroupingSets(RelOptRuleCall call,
Aggregate aggregate) {
+ final ImmutableBitSet aggregateGroupSet = aggregate.getGroupSet();
+ final ImmutableList<ImmutableBitSet> aggregateGroupingSets =
aggregate.getGroupSets();
+
final Set<ImmutableBitSet> groupSetTreeSet =
new TreeSet<>(ImmutableBitSet.ORDERING);
- // GroupSet to distinct filter arg map,
- // filterArg will be -1 for non-distinct agg call.
- // Using `Set` here because it's possible that two agg calls
- // have different filterArgs but same groupSet.
+ // Map from a set of group keys -> which filter args (if any) contributed
+ // to that combination. Used to generate boolean marker columns later which
+ // indicate whether a bottom-row should be considered for a particular
+ // (grouping-set, filter) combination.
final Map<ImmutableBitSet, Set<Integer>> distinctFilterArgMap = new
HashMap<>();
+
+ // Enumerating every required grouping-set combination, including distinct
+ // args or filter columns relied on by the downstream projection.
Review Comment:
Does "downstream" refer to line 4 of the plan? If so, you could add that to the comment.
##########
core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java:
##########
@@ -520,43 +609,203 @@ private static void
rewriteUsingGroupingSets(RelOptRuleCall call,
relBuilder.project(nodes);
}
- int x = groupCount;
- final ImmutableBitSet groupSet = aggregate.getGroupSet();
- final List<AggregateCall> newCalls = new ArrayList<>();
- for (AggregateCall aggCall : aggregate.getAggCallList()) {
- final int newFilterArg;
- final List<Integer> newArgList;
- final SqlAggFunction aggregation;
+ // Compute the remapped top-group key and grouping sets. The top-group key
+ // selects which fields of the bottom result correspond to the original
+ // aggregate's group-by columns. Upper aggregates will group by this key.
+ final ImmutableBitSet topGroupKey = remap(fullGroupSet, aggregateGroupSet);
+ final ImmutableList<ImmutableBitSet> topGroupingSets =
+ remap(fullGroupSet, aggregate.getGroupSets());
+ final int topGroupCount = topGroupKey.cardinality();
+ final boolean needsGroupingIndicators = aggregate.getGroupType() !=
Group.SIMPLE;
+ final List<Integer> groupingIndicatorOrdinals;
+ if (needsGroupingIndicators) {
+ groupingIndicatorOrdinals =
+ new ArrayList<>(Collections.nCopies(aggregateGroupingSets.size(),
-1));
+ } else {
+ groupingIndicatorOrdinals = ImmutableList.of();
+ }
+
+ int valueIndex = bottomGroupCount;
+ final List<AggregateCall> upperAggCalls = new ArrayList<>();
+ final List<List<Integer>> aggCallOrdinals = new ArrayList<>();
+ final List<AggregateCall> aggCalls = aggregate.getAggCallList();
+
+ // Build upper aggregates per declared grouping set. For each original
+ // aggCall we create one upper agg per declared grouping set. The upper
+ // aggregate groups by {@code topGroupKey} and uses the boolean marker
+ // columns (placed at known ordinals) as the FILTER argument for the
+ // corresponding per-group aggregation. The list {@code aggCallOrdinals}
+ // records, for each original aggCall, the output field ordinals of the
+ // corresponding upper-aggregate results (one per grouping set).
+ for (AggregateCall aggCall : aggCalls) {
+ final List<Integer> ordinals = new ArrayList<>();
if (!aggCall.isDistinct()) {
- aggregation = SqlStdOperatorTable.MIN;
- newArgList = ImmutableIntList.of(x++);
- newFilterArg =
- requireNonNull(filters.get(Pair.of(groupSet, -1)),
- "filters.get(Pair.of(groupSet, -1))");
+ final int inputIndex = valueIndex++;
+ final List<Integer> args = ImmutableIntList.of(inputIndex);
+ for (int g = 0; g < aggregateGroupingSets.size(); g++) {
+ final ImmutableBitSet groupingSet = aggregateGroupingSets.get(g);
+ final int newFilterArg =
+ requireNonNull(filters.get(Pair.of(groupingSet, -1)),
+ () -> "filters.get(" + groupingSet + ", -1)");
+ final String upperAggName = upperAggCallName(aggCall, g);
+ final AggregateCall newCall =
+ AggregateCall.create(aggCall.getParserPosition(),
+ SqlStdOperatorTable.MIN, false, aggCall.isApproximate(),
+ aggCall.ignoreNulls(), aggCall.rexList, args, newFilterArg,
+ aggCall.distinctKeys, aggCall.collation,
aggregate.hasEmptyGroup(),
+ relBuilder.peek(), null, upperAggName);
+ upperAggCalls.add(newCall);
+ ordinals.add(topGroupCount + upperAggCalls.size() - 1);
+ }
} else {
- aggregation = aggCall.getAggregation();
- newArgList = remap(fullGroupSet, aggCall.getArgList());
- final ImmutableBitSet newGroupSet =
ImmutableBitSet.of(aggCall.getArgList())
- .setIf(aggCall.filterArg, aggCall.filterArg >= 0)
- .union(groupSet);
- newFilterArg =
- requireNonNull(filters.get(Pair.of(newGroupSet,
aggCall.filterArg)),
- "filters.get(of(newGroupSet, aggCall.filterArg))");
+ final List<Integer> newArgList = remap(fullGroupSet,
aggCall.getArgList());
+ for (int g = 0; g < aggregateGroupingSets.size(); g++) {
+ final ImmutableBitSet groupingSet = aggregateGroupingSets.get(g);
+ final ImmutableBitSet newGroupSet =
ImmutableBitSet.of(aggCall.getArgList())
+ .setIf(aggCall.filterArg, aggCall.filterArg >= 0)
+ .union(groupingSet);
+ final int newFilterArg =
+ requireNonNull(filters.get(Pair.of(newGroupSet,
aggCall.filterArg)),
+ () -> "filters.get(" + newGroupSet + ", " +
aggCall.filterArg + ")");
+ final String upperAggName = upperAggCallName(aggCall, g);
+ final AggregateCall newCall =
+ AggregateCall.create(aggCall.getParserPosition(),
aggCall.getAggregation(), false,
+ aggCall.isApproximate(), aggCall.ignoreNulls(),
+ aggCall.rexList, newArgList, newFilterArg,
+ aggCall.distinctKeys, aggCall.collation,
+ aggregate.hasEmptyGroup(), relBuilder.peek(), null,
upperAggName);
+ upperAggCalls.add(newCall);
+ ordinals.add(topGroupCount + upperAggCalls.size() - 1);
+ }
}
- final AggregateCall newCall =
- AggregateCall.create(aggCall.getParserPosition(), aggregation, false,
- aggCall.isApproximate(), aggCall.ignoreNulls(),
- aggCall.rexList, newArgList, newFilterArg,
- aggCall.distinctKeys, aggCall.collation,
- aggregate.hasEmptyGroup(), relBuilder.peek(), null,
aggCall.name);
- newCalls.add(newCall);
+ aggCallOrdinals.add(ordinals);
}
+ // If grouping indicators are needed (ROLLUP/CUBE/GROUPING SETS with more
+ // than one grouping set), add COUNT(...) presence calls which are later
+ // used to determine whether a grouping set produced any rows. These are
+ // used to implement semantics where empty grouping sets still must
+ // produce a result.
+ if (needsGroupingIndicators) {
+ for (int g = 0; g < aggregateGroupingSets.size(); g++) {
+ final ImmutableBitSet groupingSet = aggregateGroupingSets.get(g);
+ final Integer filterField = filters.get(Pair.of(groupingSet, -1));
+ if (filterField == null) {
+ continue;
+ }
+ final AggregateCall presenceCall =
+ AggregateCall.create(SqlStdOperatorTable.COUNT, false, false,
false,
+ ImmutableList.of(), ImmutableIntList.of(), filterField, null,
+ RelCollations.EMPTY, aggregate.hasEmptyGroup(),
relBuilder.peek(), null,
+ "$g_present_" + g);
+ upperAggCalls.add(presenceCall);
+ groupingIndicatorOrdinals.set(g, topGroupCount + upperAggCalls.size()
- 1);
+ }
+ }
+
+ // If there are multiple declared grouping sets then we need a
+ // GROUPING() value in the upper aggregate so we can later route results
+ // to the correct output using CASE expressions. Compute and append that
+ // grouping-call if required.
Review Comment:
These are also line 3 items; consider referencing line 3 in this comment too.
##########
core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java:
##########
@@ -520,43 +609,203 @@ private static void
rewriteUsingGroupingSets(RelOptRuleCall call,
relBuilder.project(nodes);
}
- int x = groupCount;
- final ImmutableBitSet groupSet = aggregate.getGroupSet();
- final List<AggregateCall> newCalls = new ArrayList<>();
- for (AggregateCall aggCall : aggregate.getAggCallList()) {
- final int newFilterArg;
- final List<Integer> newArgList;
- final SqlAggFunction aggregation;
+ // Compute the remapped top-group key and grouping sets. The top-group key
+ // selects which fields of the bottom result correspond to the original
+ // aggregate's group-by columns. Upper aggregates will group by this key.
+ final ImmutableBitSet topGroupKey = remap(fullGroupSet, aggregateGroupSet);
+ final ImmutableList<ImmutableBitSet> topGroupingSets =
+ remap(fullGroupSet, aggregate.getGroupSets());
+ final int topGroupCount = topGroupKey.cardinality();
+ final boolean needsGroupingIndicators = aggregate.getGroupType() !=
Group.SIMPLE;
+ final List<Integer> groupingIndicatorOrdinals;
+ if (needsGroupingIndicators) {
+ groupingIndicatorOrdinals =
+ new ArrayList<>(Collections.nCopies(aggregateGroupingSets.size(),
-1));
+ } else {
+ groupingIndicatorOrdinals = ImmutableList.of();
+ }
+
+ int valueIndex = bottomGroupCount;
+ final List<AggregateCall> upperAggCalls = new ArrayList<>();
+ final List<List<Integer>> aggCallOrdinals = new ArrayList<>();
+ final List<AggregateCall> aggCalls = aggregate.getAggCallList();
+
+ // Build upper aggregates per declared grouping set. For each original
+ // aggCall we create one upper agg per declared grouping set. The upper
+ // aggregate groups by {@code topGroupKey} and uses the boolean marker
+ // columns (placed at known ordinals) as the FILTER argument for the
+ // corresponding per-group aggregation. The list {@code aggCallOrdinals}
+ // records, for each original aggCall, the output field ordinals of the
+ // corresponding upper-aggregate results (one per grouping set).
+ for (AggregateCall aggCall : aggCalls) {
+ final List<Integer> ordinals = new ArrayList<>();
if (!aggCall.isDistinct()) {
- aggregation = SqlStdOperatorTable.MIN;
- newArgList = ImmutableIntList.of(x++);
- newFilterArg =
- requireNonNull(filters.get(Pair.of(groupSet, -1)),
- "filters.get(Pair.of(groupSet, -1))");
+ final int inputIndex = valueIndex++;
+ final List<Integer> args = ImmutableIntList.of(inputIndex);
+ for (int g = 0; g < aggregateGroupingSets.size(); g++) {
+ final ImmutableBitSet groupingSet = aggregateGroupingSets.get(g);
+ final int newFilterArg =
+ requireNonNull(filters.get(Pair.of(groupingSet, -1)),
+ () -> "filters.get(" + groupingSet + ", -1)");
+ final String upperAggName = upperAggCallName(aggCall, g);
+ final AggregateCall newCall =
+ AggregateCall.create(aggCall.getParserPosition(),
+ SqlStdOperatorTable.MIN, false, aggCall.isApproximate(),
+ aggCall.ignoreNulls(), aggCall.rexList, args, newFilterArg,
+ aggCall.distinctKeys, aggCall.collation,
aggregate.hasEmptyGroup(),
+ relBuilder.peek(), null, upperAggName);
+ upperAggCalls.add(newCall);
+ ordinals.add(topGroupCount + upperAggCalls.size() - 1);
+ }
} else {
- aggregation = aggCall.getAggregation();
- newArgList = remap(fullGroupSet, aggCall.getArgList());
- final ImmutableBitSet newGroupSet =
ImmutableBitSet.of(aggCall.getArgList())
- .setIf(aggCall.filterArg, aggCall.filterArg >= 0)
- .union(groupSet);
- newFilterArg =
- requireNonNull(filters.get(Pair.of(newGroupSet,
aggCall.filterArg)),
- "filters.get(of(newGroupSet, aggCall.filterArg))");
+ final List<Integer> newArgList = remap(fullGroupSet,
aggCall.getArgList());
+ for (int g = 0; g < aggregateGroupingSets.size(); g++) {
+ final ImmutableBitSet groupingSet = aggregateGroupingSets.get(g);
+ final ImmutableBitSet newGroupSet =
ImmutableBitSet.of(aggCall.getArgList())
+ .setIf(aggCall.filterArg, aggCall.filterArg >= 0)
+ .union(groupingSet);
+ final int newFilterArg =
+ requireNonNull(filters.get(Pair.of(newGroupSet,
aggCall.filterArg)),
+ () -> "filters.get(" + newGroupSet + ", " +
aggCall.filterArg + ")");
+ final String upperAggName = upperAggCallName(aggCall, g);
+ final AggregateCall newCall =
+ AggregateCall.create(aggCall.getParserPosition(),
aggCall.getAggregation(), false,
+ aggCall.isApproximate(), aggCall.ignoreNulls(),
+ aggCall.rexList, newArgList, newFilterArg,
+ aggCall.distinctKeys, aggCall.collation,
+ aggregate.hasEmptyGroup(), relBuilder.peek(), null,
upperAggName);
+ upperAggCalls.add(newCall);
+ ordinals.add(topGroupCount + upperAggCalls.size() - 1);
+ }
}
- final AggregateCall newCall =
- AggregateCall.create(aggCall.getParserPosition(), aggregation, false,
- aggCall.isApproximate(), aggCall.ignoreNulls(),
- aggCall.rexList, newArgList, newFilterArg,
- aggCall.distinctKeys, aggCall.collation,
- aggregate.hasEmptyGroup(), relBuilder.peek(), null,
aggCall.name);
- newCalls.add(newCall);
+ aggCallOrdinals.add(ordinals);
}
+ // If grouping indicators are needed (ROLLUP/CUBE/GROUPING SETS with more
+ // than one grouping set), add COUNT(...) presence calls which are later
+ // used to determine whether a grouping set produced any rows. These are
+ // used to implement semantics where empty grouping sets still must
+ // produce a result.
+ if (needsGroupingIndicators) {
+ for (int g = 0; g < aggregateGroupingSets.size(); g++) {
+ final ImmutableBitSet groupingSet = aggregateGroupingSets.get(g);
+ final Integer filterField = filters.get(Pair.of(groupingSet, -1));
+ if (filterField == null) {
+ continue;
+ }
+ final AggregateCall presenceCall =
+ AggregateCall.create(SqlStdOperatorTable.COUNT, false, false,
false,
+ ImmutableList.of(), ImmutableIntList.of(), filterField, null,
+ RelCollations.EMPTY, aggregate.hasEmptyGroup(),
relBuilder.peek(), null,
+ "$g_present_" + g);
+ upperAggCalls.add(presenceCall);
+ groupingIndicatorOrdinals.set(g, topGroupCount + upperAggCalls.size()
- 1);
+ }
+ }
+
+ // If there are multiple declared grouping sets then we need a
+ // GROUPING() value in the upper aggregate so we can later route results
+ // to the correct output using CASE expressions. Compute and append that
+ // grouping-call if required.
+ final boolean needsGroupingId = aggregateGroupingSets.size() > 1;
+ final int groupingIdOrdinal;
+ if (needsGroupingId) {
+ final ImmutableBitSet remappedGroupSet = remap(fullGroupSet,
aggregateGroupSet);
+ final AggregateCall groupingCall =
+ AggregateCall.create(SqlStdOperatorTable.GROUPING, false, false,
false,
+ ImmutableList.of(),
ImmutableIntList.copyOf(remappedGroupSet.asList()), -1, null,
+ RelCollations.EMPTY, aggregate.hasEmptyGroup(),
relBuilder.peek(), null, "$g_final");
+ upperAggCalls.add(groupingCall);
+ groupingIdOrdinal = topGroupCount + upperAggCalls.size() - 1;
+ } else {
+ groupingIdOrdinal = -1;
+ }
+
+ // Line 3: build the upper aggregate layer, grouping by the original keys
+ // and applying FILTERs (and presence/grouping columns) per declared set.
relBuilder.aggregate(
- relBuilder.groupKey(
- remap(fullGroupSet, groupSet),
- remap(fullGroupSet, aggregate.getGroupSets())),
- newCalls);
+ relBuilder.groupKey(topGroupKey, topGroupingSets),
+ upperAggCalls);
+
+ final ImmutableList<Integer> groupingIdColumns =
+ ImmutableList.copyOf(Util.range(topGroupCount));
+ final RexNode groupingIdRef = needsGroupingId ?
relBuilder.field(groupingIdOrdinal) : null;
+
+ if (needsGroupingIndicators) {
+ final List<RexNode> keepConditions = new ArrayList<>();
+ for (int g = 0; g < aggregateGroupingSets.size(); g++) {
+ final int indicatorOrdinal = groupingIndicatorOrdinals.get(g);
+ if (indicatorOrdinal < 0) {
+ continue;
+ }
+ final ImmutableBitSet groupingSet = aggregateGroupingSets.get(g);
+ final RexNode requiredRows;
+ if (groupingSet.isEmpty()) {
+ // Empty grouping sets must still produce a row even if the input is
+ // empty, so do not require any contributing tuples.
+ requiredRows = relBuilder.literal(true);
+ } else {
+ requiredRows =
+ relBuilder.greaterThan(relBuilder.field(indicatorOrdinal),
+ relBuilder.literal(0));
+ }
+
+ final RexNode groupingMatches;
+ if (needsGroupingId) {
+ final long groupingValue =
+ groupValue(groupingIdColumns, remap(aggregateGroupSet,
groupingSet));
+ groupingMatches =
+ relBuilder.equals(requireNonNull(groupingIdRef, "groupingIdRef"),
+ relBuilder.literal(groupingValue));
+ } else {
+ groupingMatches = relBuilder.literal(true);
+ }
+ keepConditions.add(relBuilder.and(groupingMatches, requiredRows));
+ }
+
+ // Line 2: filter away rows produced solely for internal combinations.
+ if (!keepConditions.isEmpty()) {
+ RexNode condition = keepConditions.get(0);
+ for (int i = 1; i < keepConditions.size(); i++) {
+ condition = relBuilder.or(condition, keepConditions.get(i));
+ }
+ relBuilder.filter(condition);
+ }
+ }
+
+ final List<RexNode> projects = new ArrayList<>();
Review Comment:
Consider adding a comment here noting that this assembles the projections for line 1 of the plan.
##########
core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java:
##########
@@ -520,43 +609,203 @@ private static void
rewriteUsingGroupingSets(RelOptRuleCall call,
relBuilder.project(nodes);
}
- int x = groupCount;
- final ImmutableBitSet groupSet = aggregate.getGroupSet();
- final List<AggregateCall> newCalls = new ArrayList<>();
- for (AggregateCall aggCall : aggregate.getAggCallList()) {
- final int newFilterArg;
- final List<Integer> newArgList;
- final SqlAggFunction aggregation;
+ // Compute the remapped top-group key and grouping sets. The top-group key
+ // selects which fields of the bottom result correspond to the original
+ // aggregate's group-by columns. Upper aggregates will group by this key.
+ final ImmutableBitSet topGroupKey = remap(fullGroupSet, aggregateGroupSet);
+ final ImmutableList<ImmutableBitSet> topGroupingSets =
+ remap(fullGroupSet, aggregate.getGroupSets());
+ final int topGroupCount = topGroupKey.cardinality();
+ final boolean needsGroupingIndicators = aggregate.getGroupType() !=
Group.SIMPLE;
+ final List<Integer> groupingIndicatorOrdinals;
+ if (needsGroupingIndicators) {
+ groupingIndicatorOrdinals =
+ new ArrayList<>(Collections.nCopies(aggregateGroupingSets.size(),
-1));
+ } else {
+ groupingIndicatorOrdinals = ImmutableList.of();
+ }
+
+ int valueIndex = bottomGroupCount;
+ final List<AggregateCall> upperAggCalls = new ArrayList<>();
+ final List<List<Integer>> aggCallOrdinals = new ArrayList<>();
+ final List<AggregateCall> aggCalls = aggregate.getAggCallList();
+
+ // Build upper aggregates per declared grouping set. For each original
+ // aggCall we create one upper agg per declared grouping set. The upper
+ // aggregate groups by {@code topGroupKey} and uses the boolean marker
+ // columns (placed at known ordinals) as the FILTER argument for the
+ // corresponding per-group aggregation. The list {@code aggCallOrdinals}
+ // records, for each original aggCall, the output field ordinals of the
+ // corresponding upper-aggregate results (one per grouping set).
+ for (AggregateCall aggCall : aggCalls) {
+ final List<Integer> ordinals = new ArrayList<>();
if (!aggCall.isDistinct()) {
- aggregation = SqlStdOperatorTable.MIN;
- newArgList = ImmutableIntList.of(x++);
- newFilterArg =
- requireNonNull(filters.get(Pair.of(groupSet, -1)),
- "filters.get(Pair.of(groupSet, -1))");
+ final int inputIndex = valueIndex++;
+ final List<Integer> args = ImmutableIntList.of(inputIndex);
+ for (int g = 0; g < aggregateGroupingSets.size(); g++) {
+ final ImmutableBitSet groupingSet = aggregateGroupingSets.get(g);
+ final int newFilterArg =
+ requireNonNull(filters.get(Pair.of(groupingSet, -1)),
+ () -> "filters.get(" + groupingSet + ", -1)");
+ final String upperAggName = upperAggCallName(aggCall, g);
+ final AggregateCall newCall =
+ AggregateCall.create(aggCall.getParserPosition(),
Review Comment:
Can you add an explanation of what MIN is doing here?
##########
core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java:
##########
@@ -520,43 +609,203 @@ private static void
rewriteUsingGroupingSets(RelOptRuleCall call,
relBuilder.project(nodes);
}
- int x = groupCount;
- final ImmutableBitSet groupSet = aggregate.getGroupSet();
- final List<AggregateCall> newCalls = new ArrayList<>();
- for (AggregateCall aggCall : aggregate.getAggCallList()) {
- final int newFilterArg;
- final List<Integer> newArgList;
- final SqlAggFunction aggregation;
+ // Compute the remapped top-group key and grouping sets. The top-group key
+ // selects which fields of the bottom result correspond to the original
+ // aggregate's group-by columns. Upper aggregates will group by this key.
+ final ImmutableBitSet topGroupKey = remap(fullGroupSet, aggregateGroupSet);
+ final ImmutableList<ImmutableBitSet> topGroupingSets =
+ remap(fullGroupSet, aggregate.getGroupSets());
+ final int topGroupCount = topGroupKey.cardinality();
+ final boolean needsGroupingIndicators = aggregate.getGroupType() !=
Group.SIMPLE;
+ final List<Integer> groupingIndicatorOrdinals;
+ if (needsGroupingIndicators) {
+ groupingIndicatorOrdinals =
+ new ArrayList<>(Collections.nCopies(aggregateGroupingSets.size(),
-1));
+ } else {
+ groupingIndicatorOrdinals = ImmutableList.of();
+ }
+
+ int valueIndex = bottomGroupCount;
+ final List<AggregateCall> upperAggCalls = new ArrayList<>();
+ final List<List<Integer>> aggCallOrdinals = new ArrayList<>();
+ final List<AggregateCall> aggCalls = aggregate.getAggCallList();
+
+ // Build upper aggregates per declared grouping set. For each original
+ // aggCall we create one upper agg per declared grouping set. The upper
+ // aggregate groups by {@code topGroupKey} and uses the boolean marker
+ // columns (placed at known ordinals) as the FILTER argument for the
+ // corresponding per-group aggregation. The list {@code aggCallOrdinals}
+ // records, for each original aggCall, the output field ordinals of the
+ // corresponding upper-aggregate results (one per grouping set).
+ for (AggregateCall aggCall : aggCalls) {
+ final List<Integer> ordinals = new ArrayList<>();
if (!aggCall.isDistinct()) {
- aggregation = SqlStdOperatorTable.MIN;
- newArgList = ImmutableIntList.of(x++);
- newFilterArg =
- requireNonNull(filters.get(Pair.of(groupSet, -1)),
- "filters.get(Pair.of(groupSet, -1))");
+ final int inputIndex = valueIndex++;
+ final List<Integer> args = ImmutableIntList.of(inputIndex);
+ for (int g = 0; g < aggregateGroupingSets.size(); g++) {
+ final ImmutableBitSet groupingSet = aggregateGroupingSets.get(g);
+ final int newFilterArg =
+ requireNonNull(filters.get(Pair.of(groupingSet, -1)),
+ () -> "filters.get(" + groupingSet + ", -1)");
+ final String upperAggName = upperAggCallName(aggCall, g);
+ final AggregateCall newCall =
+ AggregateCall.create(aggCall.getParserPosition(),
+ SqlStdOperatorTable.MIN, false, aggCall.isApproximate(),
+ aggCall.ignoreNulls(), aggCall.rexList, args, newFilterArg,
+ aggCall.distinctKeys, aggCall.collation,
aggregate.hasEmptyGroup(),
+ relBuilder.peek(), null, upperAggName);
+ upperAggCalls.add(newCall);
+ ordinals.add(topGroupCount + upperAggCalls.size() - 1);
+ }
} else {
- aggregation = aggCall.getAggregation();
- newArgList = remap(fullGroupSet, aggCall.getArgList());
- final ImmutableBitSet newGroupSet =
ImmutableBitSet.of(aggCall.getArgList())
- .setIf(aggCall.filterArg, aggCall.filterArg >= 0)
- .union(groupSet);
- newFilterArg =
- requireNonNull(filters.get(Pair.of(newGroupSet,
aggCall.filterArg)),
- "filters.get(of(newGroupSet, aggCall.filterArg))");
+ final List<Integer> newArgList = remap(fullGroupSet,
aggCall.getArgList());
+ for (int g = 0; g < aggregateGroupingSets.size(); g++) {
+ final ImmutableBitSet groupingSet = aggregateGroupingSets.get(g);
+ final ImmutableBitSet newGroupSet =
ImmutableBitSet.of(aggCall.getArgList())
+ .setIf(aggCall.filterArg, aggCall.filterArg >= 0)
+ .union(groupingSet);
+ final int newFilterArg =
+ requireNonNull(filters.get(Pair.of(newGroupSet,
aggCall.filterArg)),
+ () -> "filters.get(" + newGroupSet + ", " +
aggCall.filterArg + ")");
+ final String upperAggName = upperAggCallName(aggCall, g);
+ final AggregateCall newCall =
+ AggregateCall.create(aggCall.getParserPosition(),
aggCall.getAggregation(), false,
+ aggCall.isApproximate(), aggCall.ignoreNulls(),
+ aggCall.rexList, newArgList, newFilterArg,
+ aggCall.distinctKeys, aggCall.collation,
+ aggregate.hasEmptyGroup(), relBuilder.peek(), null,
upperAggName);
+ upperAggCalls.add(newCall);
+ ordinals.add(topGroupCount + upperAggCalls.size() - 1);
+ }
}
- final AggregateCall newCall =
- AggregateCall.create(aggCall.getParserPosition(), aggregation, false,
- aggCall.isApproximate(), aggCall.ignoreNulls(),
- aggCall.rexList, newArgList, newFilterArg,
- aggCall.distinctKeys, aggCall.collation,
- aggregate.hasEmptyGroup(), relBuilder.peek(), null,
aggCall.name);
- newCalls.add(newCall);
+ aggCallOrdinals.add(ordinals);
}
+ // If grouping indicators are needed (ROLLUP/CUBE/GROUPING SETS with more
+ // than one grouping set), add COUNT(...) presence calls which are later
+ // used to determine whether a grouping set produced any rows. These are
+ // used to implement semantics where empty grouping sets still must
+ // produce a result.
+ if (needsGroupingIndicators) {
+ for (int g = 0; g < aggregateGroupingSets.size(); g++) {
+ final ImmutableBitSet groupingSet = aggregateGroupingSets.get(g);
+ final Integer filterField = filters.get(Pair.of(groupingSet, -1));
+ if (filterField == null) {
+ continue;
+ }
+ final AggregateCall presenceCall =
+ AggregateCall.create(SqlStdOperatorTable.COUNT, false, false,
false,
+ ImmutableList.of(), ImmutableIntList.of(), filterField, null,
+ RelCollations.EMPTY, aggregate.hasEmptyGroup(),
relBuilder.peek(), null,
+ "$g_present_" + g);
+ upperAggCalls.add(presenceCall);
+ groupingIndicatorOrdinals.set(g, topGroupCount + upperAggCalls.size()
- 1);
+ }
+ }
+
+ // If there are multiple declared grouping sets then we need a
+ // GROUPING() value in the upper aggregate so we can later route results
+ // to the correct output using CASE expressions. Compute and append that
+ // grouping-call if required.
+ final boolean needsGroupingId = aggregateGroupingSets.size() > 1;
+ final int groupingIdOrdinal;
+ if (needsGroupingId) {
+ final ImmutableBitSet remappedGroupSet = remap(fullGroupSet,
aggregateGroupSet);
+ final AggregateCall groupingCall =
+ AggregateCall.create(SqlStdOperatorTable.GROUPING, false, false,
false,
+ ImmutableList.of(),
ImmutableIntList.copyOf(remappedGroupSet.asList()), -1, null,
+ RelCollations.EMPTY, aggregate.hasEmptyGroup(),
relBuilder.peek(), null, "$g_final");
+ upperAggCalls.add(groupingCall);
+ groupingIdOrdinal = topGroupCount + upperAggCalls.size() - 1;
+ } else {
+ groupingIdOrdinal = -1;
+ }
+
+ // Line 3: build the upper aggregate layer, grouping by the original keys
Review Comment:
Finally, build line 3.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]