xiedeyantu commented on code in PR #4495:
URL: https://github.com/apache/calcite/pull/4495#discussion_r2299286617


##########
core/src/main/java/org/apache/calcite/tools/RelBuilder.java:
##########
@@ -2832,67 +2834,136 @@ private RelBuilder 
rewriteAggregateWithDuplicateGroupSets(
             aggregateCalls.stream().map(AggCallPlus::aggregateCall)
                 .collect(toImmutableList())).getFieldNames();
 
-    // If n duplicates exist for a particular grouping, the {@code GROUP_ID()}
-    // function produces values in the range 0 to n-1. For each value,
-    // we need to figure out the corresponding group sets.
-    //
-    // For example, "... GROUPING SETS (a, a, b, c, c, c, c)"
-    // (i) The max value of the GROUP_ID() function returns is 3
-    // (ii) GROUPING SETS (a, b, c) produces value 0,
-    //      GROUPING SETS (a, c) produces value 1,
-    //      GROUPING SETS (c) produces value 2
-    //      GROUPING SETS (c) produces value 3
+    final Frame input = stack.pop();
+
     final Map<Integer, Set<ImmutableBitSet>> groupIdToGroupSets = new 
HashMap<>();
-    int maxGroupId = 0;
     for (Multiset.Entry<ImmutableBitSet> entry : groupSets.entrySet()) {
       int groupId = entry.getCount() - 1;
-      if (groupId > maxGroupId) {
-        maxGroupId = groupId;
-      }
       for (int i = 0; i <= groupId; i++) {
         groupIdToGroupSets.computeIfAbsent(i,
             k -> Sets.newTreeSet(ImmutableBitSet.COMPARATOR))
             .add(entry.getElement());
       }
     }
 
-    // AggregateCall list without GROUP_ID function
-    final List<AggCall> aggregateCallsWithoutGroupId =
-        new ArrayList<>(aggregateCalls);
-    aggregateCallsWithoutGroupId.removeIf(RelBuilder::isGroupId);
-
-    // For each group id value, we first construct an Aggregate without
-    // GROUP_ID() function call, and then create a Project node on top of it.
-    // The Project adds literal value for group id in right position.
-    final Frame frame = stack.pop();
-    for (int groupId = 0; groupId <= maxGroupId; groupId++) {
-      // Create the Aggregate node without GROUP_ID() call
-      stack.push(frame);
-      aggregate(groupKey(groupSet, 
castNonNull(groupIdToGroupSets.get(groupId))),
-          aggregateCallsWithoutGroupId);
-
-      final List<RexNode> selectList = new ArrayList<>();
-      final int groupExprLength = groupSet.cardinality();
-      // Project fields in group by expressions
-      for (int i = 0; i < groupExprLength; i++) {
-        selectList.add(field(i));
+    // Create aggregate for each GROUP_ID value
+    for (Map.Entry<Integer, Set<ImmutableBitSet>> entry : 
groupIdToGroupSets.entrySet()) {
+      // If n duplicates exist for a particular grouping, the {@code 
GROUP_ID()}
+      // function produces values in the range 0 to n-1. For each value,
+      // we need to figure out the corresponding group sets.
+      //
+      // For example, "... GROUPING SETS (a, a, b, c, c, c, c)"
+      // (i) The max value of the GROUP_ID() function returns is 3
+      // (ii) GROUPING SETS (a, b, c) produces value 0,
+      //      GROUPING SETS (a, c) produces value 1,
+      //      GROUPING SETS (c) produces value 2
+      //      GROUPING SETS (c) produces value 3
+      int groupId = entry.getKey();
+      Set<ImmutableBitSet> newGroupSets = entry.getValue();
+      rewriteGroupAggCalls(newGroupSets, groupSet, aggregateCalls,
+          fieldNamesIfNoRewrite, input, groupId);
+    }
+
+    return union(true, groupIdToGroupSets.size());
+  }
+
+  private void rewriteGroupAggCalls(
+      Set<ImmutableBitSet> subGroupSets,
+      ImmutableBitSet oriGroupSet,
+      List<AggCallPlus> oriAggCalls,
+      List<String> fieldNames,
+      Frame input,
+      int groupId) {
+    stack.push(input);
+    List<AggCallPlus> subAggCalls = new ArrayList<>();
+    ImmutableBitSet subGroupSet = ImmutableBitSet.union(subGroupSets);
+    List<RexNode> subProjects = new ArrayList<>();
+
+    // 1. For GroupSet
+    RelDataType subAggregateGroupSetType =
+        Aggregate.deriveRowType(getTypeFactory(), peek().getRowType(), false,
+            subGroupSet, ImmutableList.copyOf(subGroupSets), 
ImmutableList.of());
+    for (int i = 0; i < oriGroupSet.cardinality(); i++) {
+      int groupKey = oriGroupSet.nth(i);
+      if (subGroupSet.get(groupKey)) {
+        subProjects.add(
+            RexInputRef.of(
+                subGroupSet.indexOf(groupKey),
+                subAggregateGroupSetType));
+      } else {
+        // If the groupKey is not in the GroupSet, use null as a placeholder.
+        
subProjects.add(getRexBuilder().makeNullLiteral(field(groupKey).getType()));
       }
-      // Project fields in aggregate calls
-      int groupIdCount = 0;
-      for (int i = 0; i < aggregateCalls.size(); i++) {
-        if (isGroupId(aggregateCalls.get(i))) {
-          selectList.add(
-              getRexBuilder().makeExactLiteral(BigDecimal.valueOf(groupId),
-                  getTypeFactory().createSqlType(SqlTypeName.BIGINT)));
-          groupIdCount++;
+    }
+
+    // 2. For AggregateCalls
+    int newGroupCount = subGroupSet.cardinality();
+    for (AggCallPlus oriAggCall : oriAggCalls) {
+      AggregateCall aggCall = oriAggCall.aggregateCall();
+      switch (oriAggCall.op().getKind()) {
+      case GROUPING:
+        if (!subGroupSet.contains(ImmutableBitSet.of(aggCall.getArgList()))) {
+          // If the parameters of the GROUPING function cannot be fully 
covered by the GroupSet,
+          // the GROUPING function needs to be split.
+          splitGrouping(subAggCalls, subProjects, aggCall, subGroupSet);
         } else {
-          selectList.add(field(groupExprLength + i - groupIdCount));
+          subProjects.add(
+              new RexInputRef(
+                  newGroupCount + subAggCalls.size(),
+                  aggCall.getType()));
+          subAggCalls.add(oriAggCall);
         }
+        break;
+      case GROUP_ID:
+        subProjects.add(
+            getRexBuilder().makeLiteral(
+                groupId,
+                aggCall.getType()));
+        break;
+      default:
+        subProjects.add(
+            new RexInputRef(
+                newGroupCount + subAggCalls.size(),
+                aggCall.getType()));
+        subAggCalls.add(oriAggCall);
+        break;
+      }
+    }
+
+    aggregate(groupKey(subGroupSet, subGroupSets), subAggCalls);
+    project(subProjects, fieldNames);
+  }
+
+  private void splitGrouping(
+      List<AggCallPlus> aggCalls,
+      List<RexNode> projects,
+      AggregateCall grouping,
+      ImmutableBitSet groupSet) {
+    List<RexNode> splitOperands = new ArrayList<>();
+    List<Integer> groupingArgs = grouping.getArgList();
+    RelDataType groupingType = 
getTypeFactory().createSqlType(SqlTypeName.BIGINT);
+    for (int i = 0; i < groupingArgs.size(); i++) {
+      int groupingArg = groupingArgs.get(i);
+      if (groupSet.get(groupingArg)) {
+        RexInputRef groupingRef =
+            new RexInputRef(groupSet.cardinality() + aggCalls.size(), 
groupingType);
+        splitOperands.add(
+            call(SqlStdOperatorTable.MULTIPLY,
+                literal(1 << groupingArgs.size() - 1 - i),

Review Comment:
   Hi @mihaibudiu, can I start squashing the commits, or are there any other 
changes I should make?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to