silundong commented on code in PR #4495:
URL: https://github.com/apache/calcite/pull/4495#discussion_r2290409739
##########
core/src/main/java/org/apache/calcite/tools/RelBuilder.java:
##########
@@ -2856,49 +2877,105 @@ private RelBuilder
rewriteAggregateWithDuplicateGroupSets(
}
}
- // AggregateCall list without GROUP_ID function
- final List<AggCall> aggregateCallsWithoutGroupId =
- new ArrayList<>(aggregateCalls);
- aggregateCallsWithoutGroupId.removeIf(RelBuilder::isGroupId);
+ // Create aggregate for each GROUP_ID value
+ for (Map.Entry<Integer, Set<ImmutableBitSet>> entry :
groupIdToGroupSets.entrySet()) {
+ // If n duplicates exist for a particular grouping, the {@code
GROUP_ID()}
+ // function produces values in the range 0 to n-1. For each value,
+ // we need to figure out the corresponding group sets.
+ //
+ // For example, "... GROUPING SETS (a, a, b, c, c, c, c)"
+ // (i) The max value of the GROUP_ID() function returns is 3
+ // (ii) GROUPING SETS (a, b, c) produces value 0,
+ // GROUPING SETS (a, c) produces value 1,
+ // GROUPING SETS (c) produces value 2
+ // GROUPING SETS (c) produces value 3
+ int groupId = entry.getKey();
+ Set<ImmutableBitSet> newGroupSets = entry.getValue();
+ rewriteGroupAggCalls(newGroupSets, groupSet, aggregateCalls, groupCount,
+ fieldNamesIfNoRewrite, mapping, input, groupId);
+ }
- // For each group id value, we first construct an Aggregate without
- // GROUP_ID() function call, and then create a Project node on top of it.
- // The Project adds literal value for group id in right position.
- final Frame frame = stack.pop();
- for (int groupId = 0; groupId <= maxGroupId; groupId++) {
- // Create the Aggregate node without GROUP_ID() call
- stack.push(frame);
- aggregate(groupKey(groupSet,
castNonNull(groupIdToGroupSets.get(groupId))),
- aggregateCallsWithoutGroupId);
-
- final List<RexNode> selectList = new ArrayList<>();
- final int groupExprLength = groupSet.cardinality();
- // Project fields in group by expressions
- for (int i = 0; i < groupExprLength; i++) {
- selectList.add(field(i));
- }
- // Project fields in aggregate calls
- int groupIdCount = 0;
- for (int i = 0; i < aggregateCalls.size(); i++) {
- if (isGroupId(aggregateCalls.get(i))) {
- selectList.add(
- getRexBuilder().makeExactLiteral(BigDecimal.valueOf(groupId),
- getTypeFactory().createSqlType(SqlTypeName.BIGINT)));
- groupIdCount++;
- } else {
- selectList.add(field(groupExprLength + i - groupIdCount));
- }
+ return union(true, groupIdToGroupSets.size());
+ }
+
+ /**
+ * Rewrite aggregate calls with special handling for GROUPING and GROUP_ID
functions,
+ * including NULL padding for missing grouping fields.
+ *
+ * @param groupSets The groupSets to use for this aggregate
+ * @param groupSet The groupSet to use for this aggregate
+ * @param aggregateCalls List of aggregate calls for this aggregate
+ * @param originalGroupCount The original groupSet size
+ * @param fieldNames Field names for the output row type
+ * @param mapping Field index mapping between input and output
+ * @param input The input relational expression
+ * @param groupId The GROUP_ID value to use for this aggregate
+ */
+ private void rewriteGroupAggCalls(Set<ImmutableBitSet> groupSets,
ImmutableBitSet groupSet,
+ List<AggregateCall> aggregateCalls, int originalGroupCount, List<String>
fieldNames,
+ Mappings.TargetMapping mapping, Frame input, int groupId) {
+ stack.push(input);
+
+ // specialFields records special values and their indexes.
+ // For example, GROUP_ID or GROUPING will be treated as
+ // numeric literals or NULL literals.
+ Map<Integer, RexNode> specialFields = new HashMap<>();
+ List<AggregateCall> aggCalls = new ArrayList<>();
+ for (int i = 0; i < aggregateCalls.size(); i++) {
+ AggregateCall aggCall = aggregateCalls.get(i);
+ switch (aggCall.getAggregation().getKind()) {
+ case GROUPING:
+ int grouping = calculateGroupingValue(groupSets.iterator().next(),
aggCall.getArgList());
+ specialFields.put(originalGroupCount + i,
+ getRexBuilder().makeLiteral(grouping, aggCall.getType(), true));
+ break;
+ case GROUP_ID:
+ specialFields.put(originalGroupCount + i,
+ getRexBuilder().makeLiteral(groupId, aggCall.getType()));
+ break;
+ default:
+ aggCalls.add(aggCall);
Review Comment:
Should the AggregateCall be rebuilt here? Its nullability may change from true
to false.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]