DonnyZone commented on a change in pull request #1904: [CALCITE-3893] [CALCITE-3895] SQL with GROUP_ID may generate wrong plan
URL: https://github.com/apache/calcite/pull/1904#discussion_r406609562
 
 

 ##########
 File path: core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java
 ##########
 @@ -3144,6 +3190,105 @@ private RelNode rewriteAggregateWithGroupId(Blackboard bb,
     return LogicalUnion.create(projects, true);
   }
 
+  /**
+   * Rewrite {@code AggregateCall} with grouping function. When a group set members removed,
+   * remove the function argument accordingly. If there is no argument left, return null.
+   * @param aggregateCall AggregateCall with grouping function
+   * @param originGroupSet original group set
+   * @param newGroupSet new group set
+   * @param aggCallOrdinal ordinal of this AggregateCall
+   * @param groupingValues grouping values computed when no argument left
+   * @param groupingArgOrdinals ordinals of the grouping arguments dropped
+   * @return AggregateCall after rewrite
+   */
+  private AggregateCall rewriteGrouping(final AggregateCall aggregateCall,
+      final ImmutableBitSet originGroupSet, final ImmutableBitSet newGroupSet,
+      final int aggCallOrdinal, final Map<Integer, Long> groupingValues,
+      final Map<Integer, ImmutableBitSet> groupingArgOrdinals) {
+    final List<Integer> originArgs = aggregateCall.getArgList();
+    final List<Integer> newArgs = new ArrayList<>();
+    final List<Integer> ordinals = new ArrayList<>();
+    final List<Integer> newGroupList = newGroupSet.asList();
+    for (int ordinal = 0; ordinal < originArgs.size(); ordinal++) {
+      int arg = originArgs.get(ordinal);
+      if (newGroupList.contains(arg)) {
+        newArgs.add(arg);
+      } else {
+        ordinals.add(ordinal);
+      }
+    }
+    if (newArgs.isEmpty()) {
+      long groupingValue = 1L << originGroupSet.asList().size() - 1;
+      groupingValues.put(aggCallOrdinal, groupingValue);
+      return null;
+    }
+    final ImmutableBitSet argsDropped = ImmutableBitSet.of(ordinals);
+    groupingArgOrdinals.put(aggCallOrdinal, argsDropped);
+    final int originFilterArg = aggregateCall.filterArg;
+    final int filterArg = newGroupList.contains(originFilterArg) ? originFilterArg : -1;
+    return AggregateCall.create(aggregateCall.getAggregation(), aggregateCall.isDistinct(),
+        aggregateCall.isApproximate(), aggregateCall.ignoreNulls(), newArgs,
+        filterArg, aggregateCall.collation, aggregateCall.type, aggregateCall.name);
+  }
+
+  /**
+   * Re-compute the initial value from grouping function (with partial arguments).
+   *
+   * <p>For example, assume the initial function is: {@code grouping(a, b, c)},
+   * when the second argument is always null (i.e., {@code b = null}), we can
+   * get the result of {@code grouping(a, b, c)} from {@code grouping(a, c)}.
+   * That is, {@code value = grouping(a, c), argOrdinals = {1}, valueLen = 2}.
+   * </p>
+   *
+   * @param groupingValue the input grouping value
+   * @param argOrdinals ordinals of grouping args dropped
+   * @param valueLen length of the input number in binary
+   * @return result of the initial grouping value
+   */
+  private RexNode reComputeGrouping(final RexNode groupingValue,
+      final ImmutableBitSet argOrdinals, final int valueLen) {
+    int resultLen = valueLen;
+    RexNode result = groupingValue;
+    for (int ordinal: argOrdinals) {
+      result = insertBitInSpecificPosition(result, ordinal, resultLen);
+      resultLen++;
+    }
+    return result;
+  }
+
+  /**
+   * Insert a bit '1' into the value at specific position
+   * For example, when input = 7 (i.e., '111'), ordinal = 2, len = 3,
+   * the result is 15 (i.e., '1111')
+   * @param input the input number
+   * @param ordinal position
+   * @param len length of the input number in binary
+   * @return number after inserting bit
+   */
+  private RexNode insertBitInSpecificPosition(final RexNode input,
+       final int ordinal, final int len) {
+    assert ordinal >= 0 && ordinal <= len;
+    final int rightLen = len - ordinal;
+    final long v = (long) Math.pow(2, rightLen);
+    final RexNode vLiteral = rexBuilder.makeBigintLiteral(BigDecimal.valueOf(v));
+    // left part: num / v
+    final RexNode left = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, input, vLiteral);
+    // left value: left * v
+    final RexNode leftValue =
+        rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, left, vLiteral);
+    // right value: num - leftValue
+    final RexNode right =
+        rexBuilder.makeCall(SqlStdOperatorTable.MINUS, input, leftValue);
 
 Review comment:
   That's nice! Based on Julian's comment in Jira, I also came up with another approach.
   We can add back the removed group set members and then filter the extra rows out by their GROUPING values.
   ```
   select grouping(a, d, e, g) as g
   from t
   group by grouping sets(a, d, e, g)
   ```
   rewrites to
   ```
   select g2
   from (
     select (a, b, c, d, e, f, g, h) as g, grouping(b) as g1, grouping(c) as g2, grouping(f) as g3, grouping(h) as g4
     from t
     group by grouping sets(a, b, c, d, e, f, g, h)
   )
   where g1 = 1 and g2 = 1 and g3 = 1 and g4 = 1
   ```
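
For comparison, the bit-insertion arithmetic in the diff above can be checked on plain `long` values. This is only a minimal sketch: the diff excerpt stops before the parts are recombined, so the final step (`left * v * 2 + v + right`) is an assumption, and `GroupingBits`, `insertOneBit` and the `long`-based `reComputeGrouping` are hypothetical names for illustration, not Calcite APIs.

```java
// Hypothetical helper class for illustration only; not part of Calcite.
class GroupingBits {
  /**
   * Inserts a '1' bit into a len-bit value. The ordinal counts from the most
   * significant bit, so (len - ordinal) bits stay to the right of the new bit.
   */
  static long insertOneBit(long input, int ordinal, int len) {
    if (ordinal < 0 || ordinal > len) {
      throw new IllegalArgumentException("ordinal out of range: " + ordinal);
    }
    final int rightLen = len - ordinal;
    final long v = 1L << rightLen;        // 2^rightLen, same value as in the diff
    final long left = input / v;          // bits to the left of the insert point
    final long right = input - left * v;  // bits to the right of the insert point
    // Assumed recombination: move the left part up one extra bit, set the new bit.
    return left * v * 2 + v + right;
  }

  /**
   * Rebuilds the full grouping value from a partial one by inserting a '1'
   * for every dropped argument. Ordinals must be in ascending order.
   */
  static long reComputeGrouping(long value, int[] droppedOrdinals, int valueLen) {
    long result = value;
    int len = valueLen;
    for (int ordinal : droppedOrdinals) {
      result = insertOneBit(result, ordinal, len);
      len++;  // the value is one bit longer after each insertion
    }
    return result;
  }

  public static void main(String[] args) {
    // Javadoc example from the diff: input = 7 ('111'), ordinal = 2, len = 3.
    System.out.println(insertOneBit(7, 2, 3));                   // 15
    // grouping(a, c) = '01' with b dropped at ordinal 1 gives '011'.
    System.out.println(reComputeGrouping(1, new int[] {1}, 2));  // 3
  }
}
```

The compensate-and-filter rewrite would avoid this arithmetic entirely, at the cost of a wider grouping set list and an extra filter.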
