silundong commented on code in PR #4495:
URL: https://github.com/apache/calcite/pull/4495#discussion_r2290409739


##########
core/src/main/java/org/apache/calcite/tools/RelBuilder.java:
##########
@@ -2856,49 +2877,105 @@ private RelBuilder 
rewriteAggregateWithDuplicateGroupSets(
       }
     }
 
-    // AggregateCall list without GROUP_ID function
-    final List<AggCall> aggregateCallsWithoutGroupId =
-        new ArrayList<>(aggregateCalls);
-    aggregateCallsWithoutGroupId.removeIf(RelBuilder::isGroupId);
+    // Create aggregate for each GROUP_ID value
+    for (Map.Entry<Integer, Set<ImmutableBitSet>> entry : 
groupIdToGroupSets.entrySet()) {
+      // If n duplicates exist for a particular grouping, the {@code 
GROUP_ID()}
+      // function produces values in the range 0 to n-1. For each value,
+      // we need to figure out the corresponding group sets.
+      //
+      // For example, "... GROUPING SETS (a, a, b, c, c, c, c)"
+      // (i) The max value of the GROUP_ID() function returns is 3
+      // (ii) GROUPING SETS (a, b, c) produces value 0,
+      //      GROUPING SETS (a, c) produces value 1,
+      //      GROUPING SETS (c) produces value 2
+      //      GROUPING SETS (c) produces value 3
+      int groupId = entry.getKey();
+      Set<ImmutableBitSet> newGroupSets = entry.getValue();
+      rewriteGroupAggCalls(newGroupSets, groupSet, aggregateCalls, groupCount,
+          fieldNamesIfNoRewrite, mapping, input, groupId);
+    }
 
-    // For each group id value, we first construct an Aggregate without
-    // GROUP_ID() function call, and then create a Project node on top of it.
-    // The Project adds literal value for group id in right position.
-    final Frame frame = stack.pop();
-    for (int groupId = 0; groupId <= maxGroupId; groupId++) {
-      // Create the Aggregate node without GROUP_ID() call
-      stack.push(frame);
-      aggregate(groupKey(groupSet, 
castNonNull(groupIdToGroupSets.get(groupId))),
-          aggregateCallsWithoutGroupId);
-
-      final List<RexNode> selectList = new ArrayList<>();
-      final int groupExprLength = groupSet.cardinality();
-      // Project fields in group by expressions
-      for (int i = 0; i < groupExprLength; i++) {
-        selectList.add(field(i));
-      }
-      // Project fields in aggregate calls
-      int groupIdCount = 0;
-      for (int i = 0; i < aggregateCalls.size(); i++) {
-        if (isGroupId(aggregateCalls.get(i))) {
-          selectList.add(
-              getRexBuilder().makeExactLiteral(BigDecimal.valueOf(groupId),
-                  getTypeFactory().createSqlType(SqlTypeName.BIGINT)));
-          groupIdCount++;
-        } else {
-          selectList.add(field(groupExprLength + i - groupIdCount));
-        }
+    return union(true, groupIdToGroupSets.size());
+  }
+
+  /**
+   * Rewrite aggregate calls with special handling for GROUPING and GROUP_ID 
functions,
+   * including NULL padding for missing grouping fields.
+   *
+   * @param groupSets           The groupSets to use for this aggregate
+   * @param groupSet            The groupSet to use for this aggregate
+   * @param aggregateCalls      List of aggregate calls for this aggregate
+   * @param originalGroupCount  The original groupSet size
+   * @param fieldNames          Field names for the output row type
+   * @param mapping             Field index mapping between input and output
+   * @param input               The input relational expression
+   * @param groupId             The GROUP_ID value to use for this aggregate
+   */
+  private void rewriteGroupAggCalls(Set<ImmutableBitSet> groupSets, 
ImmutableBitSet groupSet,
+      List<AggregateCall> aggregateCalls, int originalGroupCount, List<String> 
fieldNames,
+      Mappings.TargetMapping mapping, Frame input, int groupId) {
+    stack.push(input);
+
+    // specialFields records special values and their indexes.
+    // For example, GROUP_ID or GROUPING will be treated as
+    // numeric literals or NULL literals.
+    Map<Integer, RexNode> specialFields = new HashMap<>();
+    List<AggregateCall> aggCalls = new ArrayList<>();
+    for (int i = 0; i < aggregateCalls.size(); i++) {
+      AggregateCall aggCall = aggregateCalls.get(i);
+      switch (aggCall.getAggregation().getKind()) {
+      case GROUPING:
+        int grouping = calculateGroupingValue(groupSets.iterator().next(), 
aggCall.getArgList());
+        specialFields.put(originalGroupCount + i,
+            getRexBuilder().makeLiteral(grouping, aggCall.getType(), true));
+        break;
+      case GROUP_ID:
+        specialFields.put(originalGroupCount + i,
+            getRexBuilder().makeLiteral(groupId, aggCall.getType()));
+        break;
+      default:
+        aggCalls.add(aggCall);

Review Comment:
   Should this AggregateCall be rebuilt here? Its nullability may change from true
   to false.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to