vlsi commented on a change in pull request #1756: [CALCITE-1824] GROUP_ID returns wrong result
URL: https://github.com/apache/calcite/pull/1756#discussion_r366178638
 
 

 ##########
 File path: core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java
 ##########
 @@ -3039,6 +3040,113 @@ protected final void createAggImpl(
     }
   }
 
+  /**
+   * The {@code GROUP_ID()} function is used to distinguish duplicate groups.
+   * However, as Aggregate normalizes group sets (i.e., sorting, redundancy 
removal),
+   * this information is lost in RelNode. Therefore, it is impossible to
+   * implement the function in runtime.
+   *
+   * To fill this gap, an aggregation query that contains {@code GROUP_ID()} 
function
+   * will generally be rewritten into UNION when converting to RelNode.
+   *
+   * Also see the discussion in JIRA
+   * <a 
href="https://issues.apache.org/jira/browse/CALCITE-1824";>[CALCITE-1824]
+   * GROUP_ID returns wrong result</a>.
+   */
+  private RelNode rewriteAggregateWithGroupId(Blackboard bb,
+      AggregatingSelectScope.Resolved r, AggConverter converter) {
+    final List<AggregateCall> aggregateCalls = converter.getAggCalls();
+    final ImmutableBitSet groupSet = r.groupSet;
+    final Map<ImmutableBitSet, Integer> groupSetCount = r.groupSetCount;
+
+    final List<String> fieldNamesIfNoRewrite = createAggregate(bb, groupSet,
+        r.groupSets, aggregateCalls).getRowType().getFieldNames();
+
+    // For every GROUP_ID value, collect its group sets in map
+    // E.g., GROUPING SETS (a, a, b, c, c, c, c), the map will be
+    // {0 -> (a, b, c), 1 -> (a, c), 2 -> (c), 3 -> (c)},
+    // in which the max GROUP_ID() value is 3
+    final Map<Integer, Set<ImmutableBitSet>> groupIdToGroupSets = new 
HashMap<>();
+    int maxGroupId = 0;
+    for (Map.Entry<ImmutableBitSet, Integer> entry: groupSetCount.entrySet()) {
+      int groupId = entry.getValue() - 1;
+      if (groupId > maxGroupId) {
+        maxGroupId = groupId;
+      }
+      for (int i = 0; i <= groupId; i++) {
+        addGroupSet(i, entry.getKey(), groupIdToGroupSets);
+      }
+    }
+
+    // AggregateCalls without GROUP_ID
+    final List<AggregateCall> aggregateCallsWithoutGroupId = new ArrayList<>();
+    for (AggregateCall aggregateCall : aggregateCalls) {
+      if (aggregateCall.getAggregation().kind != SqlKind.GROUP_ID) {
+        aggregateCallsWithoutGroupId.add(aggregateCall);
+      }
+    }
+
+    final List<RelNode> projects = new ArrayList<>();
+    // For each group id, we first construct an Aggregate without GROUP_ID()
+    // function call, and then create a Project node on top of it.
+    // The Project adds literal value for group id in right position.
+    for (int groupId = 0; groupId <= maxGroupId; groupId++) {
+      // Create the Aggregate node without GROUP_ID() call
+      final ImmutableList<ImmutableBitSet> groupSets =
+          ImmutableList.copyOf(groupIdToGroupSets.get(groupId));
+      final RelNode aggregate = createAggregate(bb, groupSet,
+          groupSets, aggregateCallsWithoutGroupId);
+
+      // RexLiteral for each GROUP_ID, note the type should be BIGINT
+      final RelDataType groupIdType = 
typeFactory.createSqlType(SqlTypeName.BIGINT);
+      final RexNode groupIdLiteral = rexBuilder.makeExactLiteral(
+          BigDecimal.valueOf(groupId), groupIdType);
+
+      relBuilder.push(aggregate);
+      final List<String> aggregateFieldNames = 
aggregate.getRowType().getFieldNames();
+
+      final List<RexNode> selectList = new ArrayList<>();
+      final List<String> selectListNames = new ArrayList<>();
+      final int groupExprLength = r.groupExprList.size();
+      // Project fields from group by expressions
+      for (int i = 0; i < groupExprLength; i++) {
+        selectList.add(relBuilder.field(i));
+        selectListNames.add(aggregateFieldNames.get(i));
+      }
+      // Project fields from aggregate calls
+      int groupIdCount = 0;
+      for (int i = 0; i < aggregateCalls.size(); i++) {
+        if (aggregateCalls.get(i).getAggregation().kind == SqlKind.GROUP_ID) {
+          selectList.add(groupIdLiteral);
+          selectListNames.add(fieldNamesIfNoRewrite.get(groupExprLength + i));
+          groupIdCount++;
+        } else {
+          int ordinal = groupExprLength + i - groupIdCount;
+          selectList.add(relBuilder.field(ordinal));
+          selectListNames.add(aggregateFieldNames.get(ordinal));
+        }
+      }
+      final RelNode project = relBuilder.project(selectList)
+          .rename(selectListNames).build();
+      projects.add(project);
+    }
+    if (projects.size() == 1) {
+      return projects.get(0);
 
 Review comment:
  Please add a brief comment explaining why this is present.
  It is crystal clear that the code returns the first item; however, it takes time to understand *why* this is needed.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to