[ https://issues.apache.org/jira/browse/DRILL-3962?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=18031107#comment-18031107 ]

ASF GitHub Bot commented on DRILL-3962:
---------------------------------------

rymarm commented on code in PR #3026:
URL: https://github.com/apache/drill/pull/3026#discussion_r2444421971


##########
exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillAggregateExpandGroupingSetsRule.java:
##########


Review Comment:
   How about splitting the `onMatch` method into even smaller helper methods?
   
   Something like this:
   ```java
   /**
    * Planner rule that expands GROUPING SETS, ROLLUP, and CUBE into a UNION ALL
    * of multiple aggregates, each with a single grouping set.
    *
    * This rule converts:
    *   SELECT a, b, SUM(c) FROM t GROUP BY GROUPING SETS ((a, b), (a), ())
    *
    * Into:
    *   SELECT a, b, SUM(c), 0 AS $g FROM t GROUP BY a, b
    *   UNION ALL
    *   SELECT a, null, SUM(c), 2 AS $g FROM t GROUP BY a
    *   UNION ALL
    *   SELECT null, null, SUM(c), 3 AS $g FROM t GROUP BY ()
    *
    * The $g column is the grouping ID that can be used by GROUPING() and
    * GROUPING_ID() functions.
    * Currently, the $g column is generated internally but stripped from the
    * final output.
    */
   public class DrillAggregateExpandGroupingSetsRule extends RelOptRule {
   
     public static final DrillAggregateExpandGroupingSetsRule INSTANCE =
         new DrillAggregateExpandGroupingSetsRule();
     public static final String GROUPING_ID_COLUMN_NAME = "$g";
     public static final String GROUP_ID_COLUMN_NAME = "$group_id";
     public static final String EXPRESSION_COLUMN_PLACEHOLDER = "EXPR$";
   
     private DrillAggregateExpandGroupingSetsRule() {
       super(operand(Aggregate.class, any()), DrillRelFactories.LOGICAL_BUILDER,
           "DrillAggregateExpandGroupingSetsRule");
     }
   
     @Override
     public boolean matches(RelOptRuleCall call) {
       final Aggregate aggregate = call.rel(0);
       return aggregate.getGroupSets().size() > 1
           && (aggregate instanceof DrillAggregateRel || aggregate instanceof LogicalAggregate);
     }
   
     @Override
     public void onMatch(RelOptRuleCall call) {
       final Aggregate aggregate = call.rel(0);
       final RelOptCluster cluster = aggregate.getCluster();
   
       GroupingFunctionAnalysis analysis = analyzeGroupingFunctions(aggregate.getAggCallList());
       GroupingSetOrderingResult ordering = sortAndAssignGroupIds(aggregate.getGroupSets());
   
       List<RelNode> perGroupAggregates = new ArrayList<>();
       for (int i = 0; i < ordering.sortedGroupSets.size(); i++) {
         perGroupAggregates.add(
             createAggregateForGroupingSet(call, aggregate, ordering.sortedGroupSets.get(i),
                 ordering.groupIds.get(i), analysis.regularAggCalls));
       }
   
       RelNode unionResult = buildUnion(cluster, perGroupAggregates);
       RelNode result = buildFinalProject(call, unionResult, aggregate, analysis);
   
       call.transformTo(result);
     }
     
     /**
      * Encapsulates analysis results of aggregate calls to determine
      * which are regular aggregates and which are grouping-related
      * functions (GROUPING, GROUPING_ID, GROUP_ID).
      */
     private static class GroupingFunctionAnalysis {
       final boolean hasGroupingFunctions;
       final List<AggregateCall> regularAggCalls;
       final List<AggregateCall> groupingFunctionCalls;
       final List<Integer> groupingFunctionPositions;
   
       GroupingFunctionAnalysis(List<AggregateCall> regularAggCalls,
           List<AggregateCall> groupingFunctionCalls,
           List<Integer> groupingFunctionPositions) {
         this.hasGroupingFunctions = !groupingFunctionPositions.isEmpty();
         this.regularAggCalls = regularAggCalls;
         this.groupingFunctionCalls = groupingFunctionCalls;
         this.groupingFunctionPositions = groupingFunctionPositions;
       }
     }
   
     /**
      * Holds the sorted grouping sets (largest first) and their assigned group IDs.
      */
     private static class GroupingSetOrderingResult {
       final List<ImmutableBitSet> sortedGroupSets;
       final List<Integer> groupIds;
       GroupingSetOrderingResult(List<ImmutableBitSet> sortedGroupSets, List<Integer> groupIds) {
         this.sortedGroupSets = sortedGroupSets;
         this.groupIds = groupIds;
       }
     }
   
     /**
      * Analyzes aggregate calls to identify which ones are GROUPING-related functions.
      *
      * @param aggCalls list of aggregate calls in the original aggregate
      * @return structure classifying grouping and non-grouping calls
      */
     private GroupingFunctionAnalysis analyzeGroupingFunctions(List<AggregateCall> aggCalls) {
       List<AggregateCall> regularAggCalls = new ArrayList<>();
       List<AggregateCall> groupingFunctionCalls = new ArrayList<>();
       List<Integer> groupingFunctionPositions = new ArrayList<>();
   
       for (int i = 0; i < aggCalls.size(); i++) {
         AggregateCall aggCall = aggCalls.get(i);
         SqlKind kind = aggCall.getAggregation().getKind();
         switch (kind) {
         // Enum case labels must be unqualified; qualified labels only compile on Java 21+
         case GROUPING:
         case GROUPING_ID:
         case GROUP_ID:
           groupingFunctionPositions.add(i);
           groupingFunctionCalls.add(aggCall);
           break;
         default:
           regularAggCalls.add(aggCall);
         }
       }
   
       return new GroupingFunctionAnalysis(regularAggCalls,
           groupingFunctionCalls, groupingFunctionPositions);
     }
   
     /**
      * Sorts grouping sets by decreasing cardinality and assigns each occurrence a
      * group ID: 0 for the first occurrence of a set, 1 for its first duplicate, and
      * so on. Group IDs are used to distinguish identical sets when needed.
      */
     private GroupingSetOrderingResult sortAndAssignGroupIds(List<ImmutableBitSet> groupSets) {
       List<ImmutableBitSet> sortedGroupSets = new ArrayList<>(groupSets);
       sortedGroupSets.sort((a, b) -> Integer.compare(b.cardinality(), a.cardinality()));
   
       Map<ImmutableBitSet, Integer> groupSetOccurrences = new HashMap<>();
       List<Integer> groupIds = new ArrayList<>();
   
       for (ImmutableBitSet groupSet : sortedGroupSets) {
         int groupId = groupSetOccurrences.getOrDefault(groupSet, 0);
         groupIds.add(groupId);
         groupSetOccurrences.put(groupSet, groupId + 1);
       }
   
       return new GroupingSetOrderingResult(sortedGroupSets, groupIds);
     }
   
     /**
      * Creates a single-grouping-set aggregate and adds a projection
      * with null-padding and grouping ID columns ($g and $group_id).
      */
     private RelNode createAggregateForGroupingSet(
         RelOptRuleCall call,
         Aggregate originalAgg,
         ImmutableBitSet groupSet,
         int groupId,
         List<AggregateCall> regularAggCalls) {
   
       ImmutableBitSet fullGroupSet = originalAgg.getGroupSet();
       RelOptCluster cluster = originalAgg.getCluster();
       RexBuilder rexBuilder = cluster.getRexBuilder();
       RelDataTypeFactory typeFactory = cluster.getTypeFactory();
       RelNode input = originalAgg.getInput();
   
       Aggregate newAggregate;
       if (originalAgg instanceof DrillAggregateRel) {
         newAggregate = new DrillAggregateRel(cluster, originalAgg.getTraitSet(), input,
             groupSet, ImmutableList.of(groupSet), regularAggCalls);
       } else {
         newAggregate = originalAgg.copy(originalAgg.getTraitSet(), input, groupSet,
             ImmutableList.of(groupSet), regularAggCalls);
       }
   
       List<RexNode> projects = new ArrayList<>();
       List<String> fieldNames = new ArrayList<>();
       int aggOutputIdx = 0;
       int outputColIdx = 0;
   
       // Populate grouping columns (nulls for omitted columns)
       for (int col : fullGroupSet) {
         if (groupSet.get(col)) {
           projects.add(rexBuilder.makeInputRef(newAggregate, aggOutputIdx++));
         } else {
           RelDataType nullType =
               originalAgg.getRowType().getFieldList().get(outputColIdx).getType();
           projects.add(rexBuilder.makeNullLiteral(nullType));
         }
         fieldNames.add(originalAgg.getRowType().getFieldList().get(outputColIdx++).getName());
       }
   
       // Add regular aggregates
       for (AggregateCall regCall : regularAggCalls) {
         projects.add(rexBuilder.makeInputRef(newAggregate, aggOutputIdx++));
         fieldNames.add(regCall.getName() != null ? regCall.getName() : "agg$" + aggOutputIdx);
       }
   
       // Add grouping ID ($g)
       int groupingId = computeGroupingId(fullGroupSet, groupSet);
       projects.add(rexBuilder.makeLiteral(groupingId,
           typeFactory.createSqlType(SqlTypeName.INTEGER), true));
       fieldNames.add(GROUPING_ID_COLUMN_NAME);
   
       // Add group ID ($group_id)
       projects.add(rexBuilder.makeLiteral(groupId,
           typeFactory.createSqlType(SqlTypeName.INTEGER), true));
       fieldNames.add(GROUP_ID_COLUMN_NAME);
   
       return call.builder().push(newAggregate).project(projects, fieldNames, false).build();
     }
   
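     /**
      * Computes the internal grouping ID ($g) for one grouping set: bit i is set
      * when the i-th column of {@code fullGroupSet} is not in {@code groupSet}.
      */
     private int computeGroupingId(ImmutableBitSet fullGroupSet, ImmutableBitSet groupSet) {
       int id = 0;
       int bit = 0;
       for (int col : fullGroupSet) {
         if (!groupSet.get(col)) {
           id |= (1 << bit);
         }
         bit++;
       }
       return id;
     }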
   
     /**
      * Combines all per-grouping-set aggregates into a single {@link DrillUnionRel}.
      */
     private RelNode buildUnion(RelOptCluster cluster, List<RelNode> aggregates) {
       if (aggregates.size() == 1) {
         return aggregates.get(0);
       }
       try {
         List<RelNode> convertedInputs = new ArrayList<>();
         for (RelNode agg : aggregates) {
           convertedInputs.add(
               convert(agg, agg.getTraitSet().plus(DrillRel.DRILL_LOGICAL).simplify()));
         }
         return new DrillUnionRel(cluster,
             cluster.traitSet().plus(DrillRel.DRILL_LOGICAL),
             convertedInputs,
             true,
             true,
             true);
       } catch (InvalidRelException e) {
         throw new RuntimeException("Failed to create DrillUnionRel", e);
       }
     }
   
     /**
      * Constructs the final projection after the UNION, restoring
      * the original output order and evaluating GROUPING(), GROUPING_ID(), and GROUP_ID().
      */
     private RelNode buildFinalProject(
         RelOptRuleCall call,
         RelNode unionResult,
         Aggregate aggregate,
         GroupingFunctionAnalysis analysis) {
   
       RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
       RelDataTypeFactory typeFactory = aggregate.getCluster().getTypeFactory();
       ImmutableBitSet fullGroupSet = aggregate.getGroupSet();
       List<RexNode> finalProjects = new ArrayList<>();
       List<String> finalFieldNames = new ArrayList<>();
       int numFields = unionResult.getRowType().getFieldCount();
   
       for (int i = 0; i < fullGroupSet.cardinality(); i++) {
         finalProjects.add(rexBuilder.makeInputRef(unionResult, i));
         finalFieldNames.add(unionResult.getRowType().getFieldList().get(i).getName());
       }
   
       if (analysis.hasGroupingFunctions) {
         RexNode gColumnRef = rexBuilder.makeInputRef(unionResult, numFields - 2);
         RexNode groupIdColumnRef = rexBuilder.makeInputRef(unionResult, numFields - 1);
         Map<Integer, AggregateCall> groupingFuncMap = new HashMap<>();
         for (int i = 0; i < analysis.groupingFunctionPositions.size(); i++) {
           groupingFuncMap.put(analysis.groupingFunctionPositions.get(i),
               analysis.groupingFunctionCalls.get(i));
         }
   
         int regularAggIndex = fullGroupSet.cardinality();
         for (int origPos = 0; origPos < aggregate.getAggCallList().size(); origPos++) {
           if (groupingFuncMap.containsKey(origPos)) {
             AggregateCall groupingCall = groupingFuncMap.get(origPos);
             String funcName = groupingCall.getAggregation().getName();
             if ("GROUPING".equals(funcName)) {
               processGrouping(groupingCall, fullGroupSet, rexBuilder, typeFactory,
                   gColumnRef, finalProjects, finalFieldNames);
             } else if ("GROUPING_ID".equals(funcName)) {
               processGroupingId(groupingCall, fullGroupSet, rexBuilder, typeFactory,
                   gColumnRef, finalProjects, finalFieldNames);
             } else if ("GROUP_ID".equals(funcName)) {
               finalProjects.add(groupIdColumnRef);
               String fieldName = groupingCall.getName() != null
                   ? groupingCall.getName()
                   : EXPRESSION_COLUMN_PLACEHOLDER + finalFieldNames.size();
               finalFieldNames.add(fieldName);
             }
           } else {
             finalProjects.add(rexBuilder.makeInputRef(unionResult, regularAggIndex));
             finalFieldNames.add(
                 unionResult.getRowType().getFieldList().get(regularAggIndex).getName());
             regularAggIndex++;
           }
         }
       } else {
         for (int i = fullGroupSet.cardinality(); i < numFields - 2; i++) {
           finalProjects.add(rexBuilder.makeInputRef(unionResult, i));
           finalFieldNames.add(unionResult.getRowType().getFieldList().get(i).getName());
         }
       }
   
       return call.builder().push(unionResult)
           .project(finalProjects, finalFieldNames, false).build();
     }
   
     /**
      * Builds the Rex expression that implements {@code GROUPING(column)}.
      */
     private void processGrouping(AggregateCall groupingCall,
         ImmutableBitSet fullGroupSet,
         RexBuilder rexBuilder,
         RelDataTypeFactory typeFactory,
         RexNode gColumnRef,
         List<RexNode> finalProjects,
         List<String> finalFieldNames) {
   
       if (groupingCall.getArgList().size() != 1) {
         throw new RuntimeException("GROUPING() expects exactly 1 argument");
       }
   
       int columnIndex = groupingCall.getArgList().get(0);
       int bitPosition = 0;
       for (int col : fullGroupSet) {
         if (col == columnIndex) break;
         bitPosition++;
       }
   
       RexNode divisor = rexBuilder.makeLiteral(
           1 << bitPosition, typeFactory.createSqlType(SqlTypeName.INTEGER), true);
   
       RexNode divided = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, gColumnRef, divisor);
       RexNode extractBit = rexBuilder.makeCall(SqlStdOperatorTable.MOD, divided,
           rexBuilder.makeLiteral(2, typeFactory.createSqlType(SqlTypeName.INTEGER), true));
   
       finalProjects.add(extractBit);
       String fieldName = groupingCall.getName() != null
           ? groupingCall.getName()
           : EXPRESSION_COLUMN_PLACEHOLDER + finalFieldNames.size();
       finalFieldNames.add(fieldName);
     }
   
     /**
      * Builds the Rex expression that implements {@code GROUPING_ID(column, ...)}.
      */
     private void processGroupingId(AggregateCall groupingCall,
         ImmutableBitSet fullGroupSet,
         RexBuilder rexBuilder,
         RelDataTypeFactory typeFactory,
         RexNode gColumnRef,
         List<RexNode> finalProjects,
         List<String> finalFieldNames) {
   
       if (groupingCall.getArgList().isEmpty()) {
         throw new RuntimeException("GROUPING_ID() expects at least one argument");
       }
   
       RexNode result = null;
       for (int i = 0; i < groupingCall.getArgList().size(); i++) {
         int columnIndex = groupingCall.getArgList().get(i);
         int bitPosition = 0;
         for (int col : fullGroupSet) {
           if (col == columnIndex) break;
           bitPosition++;
         }
   
         RexNode divisor = rexBuilder.makeLiteral(1 << bitPosition,
             typeFactory.createSqlType(SqlTypeName.INTEGER), true);
   
         RexNode divided = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, gColumnRef, divisor);
         RexNode extractBit = rexBuilder.makeCall(SqlStdOperatorTable.MOD, divided,
             rexBuilder.makeLiteral(2, typeFactory.createSqlType(SqlTypeName.INTEGER), true));
   
         int resultBitPos = groupingCall.getArgList().size() - 1 - i;
         RexNode bitInPosition = (resultBitPos > 0)
             ? rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, extractBit,
             rexBuilder.makeLiteral(1 << resultBitPos,
                 typeFactory.createSqlType(SqlTypeName.INTEGER), true))
             : extractBit;
   
         result = (result == null)
             ? bitInPosition
             : rexBuilder.makeCall(SqlStdOperatorTable.PLUS, result, bitInPosition);
       }
   
       finalProjects.add(result);
       String fieldName = groupingCall.getName() != null
           ? groupingCall.getName()
           : EXPRESSION_COLUMN_PLACEHOLDER + finalFieldNames.size();
       finalFieldNames.add(fieldName);
     }
   }
   ```
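
To make the UNION ALL rewrite in the class Javadoc concrete, here is a self-contained sketch that is not Drill code (the class and method names are made up for illustration). It evaluates `GROUP BY GROUPING SETS ((a, b), (a), ())` over a tiny in-memory table by running one GROUP BY per grouping set, null-padding the rolled-up columns, and concatenating the branches. It assumes exactly two grouping columns and tags each row with `$g` using the same bit encoding as `computeGroupingId`.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * Illustration only (not Drill code): evaluates GROUP BY GROUPING SETS by running
 * one GROUP BY per grouping set and concatenating the results (UNION ALL).
 * Rows are {a, b, c}; a and b are grouping columns 0 and 1.
 */
final class GroupingSetsExpansionSketch {

  /** Same encoding as computeGroupingId: bit i is set when the i-th grouping column is rolled up. */
  static int groupingId(int[] fullGroupSet, List<Integer> groupSet) {
    int id = 0;
    for (int bit = 0; bit < fullGroupSet.length; bit++) {
      if (!groupSet.contains(fullGroupSet[bit])) {
        id |= 1 << bit;
      }
    }
    return id;
  }

  /** One branch of the UNION ALL: GROUP BY groupSet, output row {a, b, SUM(c), $g}. */
  static List<Object[]> aggregateOneSet(
      List<Object[]> rows, int[] fullGroupSet, List<Integer> groupSet) {
    Map<List<Object>, Long> sums = new LinkedHashMap<>(); // keeps first-seen group order
    for (Object[] row : rows) {
      List<Object> key = new ArrayList<>();
      for (int col : fullGroupSet) {
        // Rolled-up columns are keyed (and output) as NULL, like the null-padding projection.
        key.add(groupSet.contains(col) ? row[col] : null);
      }
      sums.merge(key, ((Number) row[2]).longValue(), Long::sum);
    }
    int g = groupingId(fullGroupSet, groupSet);
    List<Object[]> out = new ArrayList<>();
    for (Map.Entry<List<Object>, Long> e : sums.entrySet()) {
      out.add(new Object[] {e.getKey().get(0), e.getKey().get(1), e.getValue(), g});
    }
    return out;
  }

  /** UNION ALL of the per-grouping-set aggregates, in the given set order. */
  static List<Object[]> expand(
      List<Object[]> rows, int[] fullGroupSet, List<List<Integer>> groupSets) {
    List<Object[]> union = new ArrayList<>();
    for (List<Integer> groupSet : groupSets) {
      union.addAll(aggregateOneSet(rows, fullGroupSet, groupSet));
    }
    return union;
  }

  public static void main(String[] args) {
    List<Object[]> t = Arrays.asList(
        new Object[] {"x", 1, 10},
        new Object[] {"x", 2, 20},
        new Object[] {"y", 1, 5});
    // GROUP BY GROUPING SETS ((a, b), (a), ())
    List<List<Integer>> sets =
        Arrays.asList(Arrays.asList(0, 1), Arrays.asList(0), Arrays.<Integer>asList());
    for (Object[] r : expand(t, new int[] {0, 1}, sets)) {
      System.out.println(Arrays.toString(r));
    }
  }
}
```

Running `main` prints the three `(a, b)` groups tagged with `$g = 0`, then `[x, null, 30, 2]` and `[y, null, 5, 2]` for the `(a)` branch, then `[null, null, 35, 3]` for the grand total. One simplification versus SQL: this sketch emits no row for `GROUP BY ()` over an empty input.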

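The divide-and-modulo expressions that `processGrouping` and `processGroupingId` build on top of `$g` can be checked with plain integers. The hypothetical `GroupingBitsSketch` below mirrors that arithmetic; it is not part of the PR and assumes grouping columns are identified by their input index and listed in ascending order, as `ImmutableBitSet` iterates them.

```java
/**
 * Illustration only: the integer arithmetic behind the Rex expressions that
 * processGrouping and processGroupingId build on top of the $g column.
 * fullGroupSet lists grouping column indexes in ascending order.
 */
final class GroupingBitsSketch {

  /** Index of columnIndex within fullGroupSet, i.e. the rule's bitPosition. */
  static int bitPosition(int[] fullGroupSet, int columnIndex) {
    for (int pos = 0; pos < fullGroupSet.length; pos++) {
      if (fullGroupSet[pos] == columnIndex) {
        return pos;
      }
    }
    throw new IllegalArgumentException("Not a grouping column: " + columnIndex);
  }

  /** GROUPING(col) = ($g / 2^bitPosition) % 2: 1 if col is rolled up in this row. */
  static int grouping(int g, int[] fullGroupSet, int columnIndex) {
    return (g / (1 << bitPosition(fullGroupSet, columnIndex))) % 2;
  }

  /** GROUPING_ID(c1, ..., cn): GROUPING(c1) is the most significant bit, GROUPING(cn) the least. */
  static int groupingId(int g, int[] fullGroupSet, int... columns) {
    int result = 0;
    for (int i = 0; i < columns.length; i++) {
      int resultBitPos = columns.length - 1 - i;
      result += grouping(g, fullGroupSet, columns[i]) * (1 << resultBitPos);
    }
    return result;
  }

  public static void main(String[] args) {
    int[] full = {0, 1}; // a = column 0, b = column 1
    // $g values computeGroupingId assigns to the (a, b), (a) and () branches
    for (int g : new int[] {0, 2, 3}) {
      System.out.printf("$g=%d GROUPING(a)=%d GROUPING(b)=%d GROUPING_ID(a, b)=%d%n",
          g, grouping(g, full, 0), grouping(g, full, 1), groupingId(g, full, 0, 1));
    }
  }
}
```

For `$g = 2` (the `(a)` branch) it reports `GROUPING(a)=0`, `GROUPING(b)=1` and `GROUPING_ID(a, b)=1`. Because `GROUPING_ID` puts its first argument in the most significant bit while `$g` puts the first grouping column in the least significant bit, `$g` itself equals `GROUPING_ID(b, a)` here, not `GROUPING_ID(a, b)`. Unlike the loop in the suggestion, `bitPosition` throws when the argument is not a grouping column instead of silently returning the set's cardinality.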

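`sortAndAssignGroupIds` is what lets `GROUP_ID()` tell repeated grouping sets apart. The following is a hypothetical, dependency-free mirror of that method, with `TreeSet<Integer>` standing in for Calcite's `ImmutableBitSet`; the helper names are illustrative only.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

/**
 * Illustration only: sortAndAssignGroupIds with TreeSet<Integer> standing in
 * for Calcite's ImmutableBitSet.
 */
final class GroupIdSketch {

  /** Builds a grouping set from column indexes. */
  static Set<Integer> set(Integer... columns) {
    return new TreeSet<>(Arrays.asList(columns));
  }

  /** Largest sets first; List.sort is stable, so duplicates keep their relative order. */
  static List<Set<Integer>> sortByCardinalityDesc(List<Set<Integer>> groupSets) {
    List<Set<Integer>> sorted = new ArrayList<>(groupSets);
    sorted.sort((x, y) -> Integer.compare(y.size(), x.size()));
    return sorted;
  }

  /** 0 for the first occurrence of a set, 1 for its first duplicate, and so on. */
  static List<Integer> assignGroupIds(List<Set<Integer>> sortedGroupSets) {
    Map<Set<Integer>, Integer> occurrences = new HashMap<>();
    List<Integer> groupIds = new ArrayList<>();
    for (Set<Integer> groupSet : sortedGroupSets) {
      int groupId = occurrences.getOrDefault(groupSet, 0);
      groupIds.add(groupId);
      occurrences.put(groupSet, groupId + 1);
    }
    return groupIds;
  }

  public static void main(String[] args) {
    // GROUPING SETS ((a), (a, b), (a), ()) with a = 0, b = 1
    List<Set<Integer>> sorted =
        sortByCardinalityDesc(Arrays.asList(set(0), set(0, 1), set(0), set()));
    List<Integer> ids = assignGroupIds(sorted);
    for (int i = 0; i < sorted.size(); i++) {
      System.out.println(sorted.get(i) + " -> GROUP_ID() = " + ids.get(i));
    }
  }
}
```

With `GROUPING SETS ((a), (a, b), (a), ())` the sorted order is `(a, b), (a), (a), ()`; the second `(a)` branch gets `GROUP_ID() = 1` and every other branch gets 0. Because the sort is stable, which duplicate receives the higher ID follows the order in which the sets were written.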



> Add support of ROLLUP, CUBE, GROUPING SETS, GROUPING, GROUPING_ID, GROUP_ID support
> -----------------------------------------------------------------------------------
>
>                 Key: DRILL-3962
>                 URL: https://issues.apache.org/jira/browse/DRILL-3962
>             Project: Apache Drill
>          Issue Type: New Feature
>            Reporter: Jinfeng Ni
>            Assignee: Charles Givre
>            Priority: Major
>
> These functions are important for BI analytical workloads. Currently, Calcite
> supports these functions, but neither planning nor execution in Drill
> supports them.
> DRILL-3802 blocks these functions in Drill planning, but we should provide
> support for them in both planning and execution in Drill.



--
This message was sent by Atlassian Jira
(v8.20.10#820010)
