mihaibudiu commented on code in PR #4495:
URL: https://github.com/apache/calcite/pull/4495#discussion_r2299302797


##########
core/src/main/java/org/apache/calcite/tools/RelBuilder.java:
##########
@@ -2832,67 +2834,159 @@ private RelBuilder 
rewriteAggregateWithDuplicateGroupSets(
             aggregateCalls.stream().map(AggCallPlus::aggregateCall)
                 .collect(toImmutableList())).getFieldNames();
 
-    // If n duplicates exist for a particular grouping, the {@code GROUP_ID()}
-    // function produces values in the range 0 to n-1. For each value,
-    // we need to figure out the corresponding group sets.
-    //
-    // For example, "... GROUPING SETS (a, a, b, c, c, c, c)"
-    // (i) The max value of the GROUP_ID() function returns is 3
-    // (ii) GROUPING SETS (a, b, c) produces value 0,
-    //      GROUPING SETS (a, c) produces value 1,
-    //      GROUPING SETS (c) produces value 2
-    //      GROUPING SETS (c) produces value 3
+    final Frame input = stack.pop();
+
     final Map<Integer, Set<ImmutableBitSet>> groupIdToGroupSets = new 
HashMap<>();
-    int maxGroupId = 0;
     for (Multiset.Entry<ImmutableBitSet> entry : groupSets.entrySet()) {
       int groupId = entry.getCount() - 1;
-      if (groupId > maxGroupId) {
-        maxGroupId = groupId;
-      }
       for (int i = 0; i <= groupId; i++) {
         groupIdToGroupSets.computeIfAbsent(i,
             k -> Sets.newTreeSet(ImmutableBitSet.COMPARATOR))
             .add(entry.getElement());
       }
     }
 
-    // AggregateCall list without GROUP_ID function
-    final List<AggCall> aggregateCallsWithoutGroupId =
-        new ArrayList<>(aggregateCalls);
-    aggregateCallsWithoutGroupId.removeIf(RelBuilder::isGroupId);
-
-    // For each group id value, we first construct an Aggregate without
-    // GROUP_ID() function call, and then create a Project node on top of it.
-    // The Project adds literal value for group id in right position.
-    final Frame frame = stack.pop();
-    for (int groupId = 0; groupId <= maxGroupId; groupId++) {
-      // Create the Aggregate node without GROUP_ID() call
-      stack.push(frame);
-      aggregate(groupKey(groupSet, 
castNonNull(groupIdToGroupSets.get(groupId))),
-          aggregateCallsWithoutGroupId);
-
-      final List<RexNode> selectList = new ArrayList<>();
-      final int groupExprLength = groupSet.cardinality();
-      // Project fields in group by expressions
-      for (int i = 0; i < groupExprLength; i++) {
-        selectList.add(field(i));
+    // Create aggregate for each GROUP_ID value
+    for (Map.Entry<Integer, Set<ImmutableBitSet>> entry : 
groupIdToGroupSets.entrySet()) {
+      // If n duplicates exist for a particular grouping, the {@code 
GROUP_ID()}
+      // function produces values in the range 0 to n-1. For each value,
+      // we need to figure out the corresponding group sets.
+      //
+      // For example, "... GROUPING SETS (a, a, b, c, c, c, c)"
+      // (i) The max value of the GROUP_ID() function returns is 3
+      // (ii) GROUPING SETS (a, b, c) produces value 0,
+      //      GROUPING SETS (a, c) produces value 1,
+      //      GROUPING SETS (c) produces value 2
+      //      GROUPING SETS (c) produces value 3
+      int groupId = entry.getKey();
+      Set<ImmutableBitSet> newGroupSets = entry.getValue();
+      rewriteGroupAggCalls(newGroupSets, groupSet, aggregateCalls,
+          fieldNamesIfNoRewrite, input, groupId);
+    }
+
+    return union(true, groupIdToGroupSets.size());
+  }
+
+  private void rewriteGroupAggCalls(
+      Set<ImmutableBitSet> subGroupSets,
+      ImmutableBitSet oriGroupSet,
+      List<AggCallPlus> oriAggCalls,
+      List<String> fieldNames,
+      Frame input,
+      int groupId) {
+    stack.push(input);
+    List<AggCallPlus> subAggCalls = new ArrayList<>();
+    ImmutableBitSet subGroupSet = ImmutableBitSet.union(subGroupSets);
+    List<RexNode> subProjects = new ArrayList<>();
+
+    // 1. For GroupSet
+    RelDataType subAggregateGroupSetType =
+        Aggregate.deriveRowType(getTypeFactory(), peek().getRowType(), false,
+            subGroupSet, ImmutableList.copyOf(subGroupSets), 
ImmutableList.of());
+    for (int i = 0; i < oriGroupSet.cardinality(); i++) {
+      int groupKey = oriGroupSet.nth(i);
+      if (subGroupSet.get(groupKey)) {
+        subProjects.add(
+            RexInputRef.of(
+                subGroupSet.indexOf(groupKey),
+                subAggregateGroupSetType));
+      } else {
+        // If the groupKey is not in the GroupSet, use null as a placeholder.
+        
subProjects.add(getRexBuilder().makeNullLiteral(field(groupKey).getType()));
       }
-      // Project fields in aggregate calls
-      int groupIdCount = 0;
-      for (int i = 0; i < aggregateCalls.size(); i++) {
-        if (isGroupId(aggregateCalls.get(i))) {
-          selectList.add(
-              getRexBuilder().makeExactLiteral(BigDecimal.valueOf(groupId),
-                  getTypeFactory().createSqlType(SqlTypeName.BIGINT)));
-          groupIdCount++;
+    }
+
+    // 2. For AggregateCalls
+    int newGroupCount = subGroupSet.cardinality();
+    for (AggCallPlus oriAggCall : oriAggCalls) {
+      AggregateCall aggCall = oriAggCall.aggregateCall();
+      switch (oriAggCall.op().getKind()) {
+      case GROUPING:
+        if (!subGroupSet.contains(ImmutableBitSet.of(aggCall.getArgList()))) {
+          // If the parameters of the GROUPING function cannot be fully 
covered by the GroupSet,
+          // the GROUPING function needs to be split.
+          splitGrouping(subAggCalls, subProjects, aggCall, subGroupSet);
         } else {
-          selectList.add(field(groupExprLength + i - groupIdCount));
+          subProjects.add(
+              new RexInputRef(
+                  newGroupCount + subAggCalls.size(),
+                  aggCall.getType()));
+          subAggCalls.add(oriAggCall);
         }
+        break;
+      case GROUP_ID:
+        subProjects.add(
+            getRexBuilder().makeLiteral(
+                groupId,
+                aggCall.getType()));
+        break;
+      default:
+        subProjects.add(
+            new RexInputRef(
+                newGroupCount + subAggCalls.size(),
+                aggCall.getType()));
+        subAggCalls.add(oriAggCall);
+        break;
       }
-      project(selectList, fieldNamesIfNoRewrite);
     }
 
-    return union(true, maxGroupId + 1);
+    aggregate(groupKey(subGroupSet, subGroupSets), subAggCalls);
+    project(subProjects, fieldNames);
+  }
+
+  /**
+   * This method is used to expand the SQL GROUPING operator
+   * into a set of expressions. For example, it expands GROUPING(x, y, z)
+   * into 2^2 * GROUPING(x) + 2^1 * GROUPING(y) + 2^0 * GROUPING(z)
+   *
+   * <p>For example, in "GROUP BY GROUPING SETS ((a, b), (a), (), (b))":
+   * In the grouping set (a, b), GROUPING(a) = 0, GROUPING(b) = 0,
+   * GROUPING(a, b) = 2^1 * 0 + 2^0 * 0 => 0.
+   * In the grouping set (a), GROUPING(a) = 0, GROUPING(b) = 1,
+   * GROUPING(a, b) = 2^1 * 0 + 2^0 * 1 => 1.
+   * In the grouping set (), GROUPING(a) = 1, GROUPING(b) = 1,
+   * GROUPING(a, b) = 2^1 * 1 + 2^0 * 1 = 3.
+   * In the grouping set (b), GROUPING(a) = 1, GROUPING(b) = 0,
+   * GROUPING(a, b) = 2^1 * 1 + 2^0 * 0 = 2.
+   * Thus, GROUPING(a, b) produces 0, 1, 3, and 2 for each grouping set 
respectively.
+   */
+  private void splitGrouping(
+      List<AggCallPlus> aggCalls,
+      List<RexNode> projects,
+      AggregateCall grouping,
+      ImmutableBitSet groupSet) {
+    List<RexNode> splitOperands = new ArrayList<>();
+    List<Integer> groupingArgs = grouping.getArgList();
+    RelDataType groupingType = 
getTypeFactory().createSqlType(SqlTypeName.BIGINT);
+    for (int i = 0; i < groupingArgs.size(); i++) {
+      int groupingArg = groupingArgs.get(i);
+      if (groupSet.get(groupingArg)) {
+        RexInputRef groupingRef =
+            new RexInputRef(groupSet.cardinality() + aggCalls.size(), 
groupingType);
+        splitOperands.add(
+            call(SqlStdOperatorTable.MULTIPLY,
+                literal(safeShift(groupingArgs.size() - 1 - i)),
+                groupingRef));
+        aggCalls.add((AggCallPlus) aggregateCall(SqlStdOperatorTable.GROUPING, 
field(groupingArg)));
+      } else {
+        // Sets GROUPING function value to 1 for parameters not in GroupSet 
and calculates offset.
+        splitOperands.add(literal(safeShift(groupingArgs.size() - 1 - i)));
+      }
+    }
+
+    // Plus all expanded expressions (including GROUPING functions and 
constants).
+    RexNode plus = splitOperands.get(0);
+    for (int i = 1; i < splitOperands.size(); i++) {
+      plus = call(SqlStdOperatorTable.PLUS, plus, splitOperands.get(i));
+    }
+    projects.add(plus);
+  }
+
+  private static int safeShift(int shift) {
+    if (shift < 0 || shift >= Integer.SIZE) {
+      throw new IllegalArgumentException("Invalid shift amount: " + shift);

Review Comment:
   The problem with this error message is that it doesn't tell the user what
the fault is.
   The underlying fault is that there are too many grouping keys.
   I think the error message should say that the number of keys cannot
exceed 31 when using GROUPING_ID.
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to