[
https://issues.apache.org/jira/browse/DRILL-3962?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=18031107#comment-18031107
]
ASF GitHub Bot commented on DRILL-3962:
---------------------------------------
rymarm commented on code in PR #3026:
URL: https://github.com/apache/drill/pull/3026#discussion_r2444421971
##########
exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillAggregateExpandGroupingSetsRule.java:
##########
Review Comment:
How about splitting the `onMatch` method into even more methods?
Something like this:
```java
/**
* Planner rule that expands GROUPING SETS, ROLLUP, and CUBE into a UNION ALL
* of multiple aggregates, each with a single grouping set.
*
* This rule converts:
* SELECT a, b, SUM(c) FROM t GROUP BY GROUPING SETS ((a, b), (a), ())
*
* Into:
* SELECT a, b, SUM(c), 0 AS $g FROM t GROUP BY a, b
* UNION ALL
* SELECT a, null, SUM(c), 1 AS $g FROM t GROUP BY a
* UNION ALL
* SELECT null, null, SUM(c), 3 AS $g FROM t GROUP BY ()
*
* The $g column is the grouping ID that can be used by GROUPING() and
GROUPING_ID() functions.
* Currently, the $g column is generated internally but stripped from the
final output.
*/
public class DrillAggregateExpandGroupingSetsRule extends RelOptRule {
public static final DrillAggregateExpandGroupingSetsRule INSTANCE =
new DrillAggregateExpandGroupingSetsRule();
public static final String GROUPING_ID_COLUMN_NAME = "$g";
public static final String GROUP_ID_COLUMN_NAME = "$group_id";
public static final String EXPRESSION_COLUMN_PLACEHOLDER = "EXPR$";
private DrillAggregateExpandGroupingSetsRule() {
super(operand(Aggregate.class, any()), DrillRelFactories.LOGICAL_BUILDER,
"DrillAggregateExpandGroupingSetsRule");
}
@Override
public boolean matches(RelOptRuleCall call) {
final Aggregate aggregate = call.rel(0);
return aggregate.getGroupSets().size() > 1
&& (aggregate instanceof DrillAggregateRel || aggregate instanceof
LogicalAggregate);
}
@Override
public void onMatch(RelOptRuleCall call) {
final Aggregate aggregate = call.rel(0);
final RelOptCluster cluster = aggregate.getCluster();
GroupingFunctionAnalysis analysis =
analyzeGroupingFunctions(aggregate.getAggCallList());
GroupingSetOrderingResult ordering =
sortAndAssignGroupIds(aggregate.getGroupSets());
List<RelNode> perGroupAggregates = new ArrayList<>();
for (int i = 0; i < ordering.sortedGroupSets.size(); i++) {
perGroupAggregates.add(
createAggregateForGroupingSet(call, aggregate,
ordering.sortedGroupSets.get(i),
ordering.groupIds.get(i), analysis.regularAggCalls));
}
RelNode unionResult = buildUnion(cluster, perGroupAggregates);
RelNode result = buildFinalProject(call, unionResult, aggregate,
analysis);
call.transformTo(result);
}
/**
* Encapsulates analysis results of aggregate calls to determine
* which are regular aggregates and which are grouping-related
* functions (GROUPING, GROUPING_ID, GROUP_ID).
*/
private static class GroupingFunctionAnalysis {
final boolean hasGroupingFunctions;
final List<AggregateCall> regularAggCalls;
final List<AggregateCall> groupingFunctionCalls;
final List<Integer> groupingFunctionPositions;
GroupingFunctionAnalysis(List<AggregateCall> regularAggCalls,
List<AggregateCall> groupingFunctionCalls,
List<Integer> groupingFunctionPositions) {
this.hasGroupingFunctions = !groupingFunctionPositions.isEmpty();
this.regularAggCalls = regularAggCalls;
this.groupingFunctionCalls = groupingFunctionCalls;
this.groupingFunctionPositions = groupingFunctionPositions;
}
}
/**
* Holds the sorted grouping sets (largest first) and their assigned group
IDs.
*/
private static class GroupingSetOrderingResult {
final List<ImmutableBitSet> sortedGroupSets;
final List<Integer> groupIds;
GroupingSetOrderingResult(List<ImmutableBitSet> sortedGroupSets,
List<Integer> groupIds) {
this.sortedGroupSets = sortedGroupSets;
this.groupIds = groupIds;
}
}
/**
* Analyzes aggregate calls to identify which ones are GROUPING-related
functions.
*
* @param aggCalls list of aggregate calls in the original aggregate
* @return structure classifying grouping and non-grouping calls
*/
private GroupingFunctionAnalysis
analyzeGroupingFunctions(List<AggregateCall> aggCalls) {
List<AggregateCall> regularAggCalls = new ArrayList<>();
List<AggregateCall> groupingFunctionCalls = new ArrayList<>();
List<Integer> groupingFunctionPositions = new ArrayList<>();
for (int i = 0; i < aggCalls.size(); i++) {
AggregateCall aggCall = aggCalls.get(i);
SqlKind kind = aggCall.getAggregation().getKind();
switch (kind) {
case SqlKind.GROUPING:
case SqlKind.GROUPING_ID:
case SqlKind.GROUP_ID:
groupingFunctionPositions.add(i);
groupingFunctionCalls.add(aggCall);
break;
default:
regularAggCalls.add(aggCall);
}
}
return new GroupingFunctionAnalysis(regularAggCalls,
groupingFunctionCalls, groupingFunctionPositions);
}
/**
* Sorts grouping sets by decreasing cardinality and assigns a unique
group ID
* for each occurrence. Group IDs are used to distinguish identical sets
when needed.
*/
private GroupingSetOrderingResult
sortAndAssignGroupIds(List<ImmutableBitSet> groupSets) {
List<ImmutableBitSet> sortedGroupSets = new ArrayList<>(groupSets);
sortedGroupSets.sort((a, b) -> Integer.compare(b.cardinality(),
a.cardinality()));
Map<ImmutableBitSet, Integer> groupSetOccurrences = new HashMap<>();
List<Integer> groupIds = new ArrayList<>();
for (ImmutableBitSet groupSet : sortedGroupSets) {
int groupId = groupSetOccurrences.getOrDefault(groupSet, 0);
groupIds.add(groupId);
groupSetOccurrences.put(groupSet, groupId + 1);
}
return new GroupingSetOrderingResult(sortedGroupSets, groupIds);
}
/**
* Creates a single-grouping-set aggregate and adds a projection
* with null-padding and grouping ID columns ($g and $group_id).
*/
private RelNode createAggregateForGroupingSet(
RelOptRuleCall call,
Aggregate originalAgg,
ImmutableBitSet groupSet,
int groupId,
List<AggregateCall> regularAggCalls) {
ImmutableBitSet fullGroupSet = aggregate.getGroupSet();
RelOptCluster cluster = originalAgg.getCluster();
RexBuilder rexBuilder = cluster.getRexBuilder();
RelDataTypeFactory typeFactory = cluster.getTypeFactory();
RelNode input = originalAgg.getInput();
Aggregate newAggregate;
if (originalAgg instanceof DrillAggregateRel) {
newAggregate = new DrillAggregateRel(cluster,
originalAgg.getTraitSet(), input,
groupSet, ImmutableList.of(groupSet), regularAggCalls);
} else {
newAggregate = originalAgg.copy(originalAgg.getTraitSet(), input,
groupSet,
ImmutableList.of(groupSet), regularAggCalls);
}
List<RexNode> projects = new ArrayList<>();
List<String> fieldNames = new ArrayList<>();
int aggOutputIdx = 0;
int outputColIdx = 0;
// Populate grouping columns (nulls for omitted columns)
for (int col : fullGroupSet) {
if (groupSet.get(col)) {
projects.add(rexBuilder.makeInputRef(newAggregate, aggOutputIdx++));
} else {
RelDataType nullType =
originalAgg.getRowType().getFieldList().get(outputColIdx).getType();
projects.add(rexBuilder.makeNullLiteral(nullType));
}
fieldNames.add(originalAgg.getRowType().getFieldList().get(outputColIdx++).getName());
}
// Add regular aggregates
for (AggregateCall regCall : regularAggCalls) {
projects.add(rexBuilder.makeInputRef(newAggregate, aggOutputIdx++));
fieldNames.add(regCall.getName() != null ? regCall.getName() : "agg$"
+ aggOutputIdx);
}
// Add grouping ID ($g)
int groupingId = computeGroupingId(fullGroupSet, groupSet);
projects.add(rexBuilder.makeLiteral(groupingId,
typeFactory.createSqlType(SqlTypeName.INTEGER), true));
fieldNames.add(GROUPING_ID_COLUMN_NAME);
// Add group ID ($group_id)
projects.add(rexBuilder.makeLiteral(groupId,
typeFactory.createSqlType(SqlTypeName.INTEGER), true));
fieldNames.add(GROUP_ID_COLUMN_NAME);
return call.builder().push(newAggregate).project(projects, fieldNames,
false).build();
}
private int computeGroupingId(ImmutableBitSet fullGroupSet,
ImmutableBitSet groupSet) {
int id = 0;
int bit = 0;
for (int col : fullGroupSet) {
if (!groupSet.get(col)) id |= (1 << bit);
bit++;
}
return id;
}
/**
* Combines all per-grouping-set aggregates into a single {@link
DrillUnionRel}.
*/
private RelNode buildUnion(RelOptCluster cluster, List<RelNode>
aggregates) {
if (aggregates.size() == 1) {
return aggregates.get(0);
}
try {
List<RelNode> convertedInputs = new ArrayList<>();
for (RelNode agg : aggregates) {
convertedInputs.add(convert(agg,
agg.getTraitSet().plus(DrillRel.DRILL_LOGICAL).simplify()));
}
return new DrillUnionRel(cluster,
cluster.traitSet().plus(DrillRel.DRILL_LOGICAL),
convertedInputs,
true,
true,
true);
} catch (InvalidRelException e) {
throw new RuntimeException("Failed to create DrillUnionRel", e);
}
}
/**
* Constructs the final projection after the UNION, restoring
* the original output order and evaluating GROUPING(), GROUPING_ID(), and
GROUP_ID().
*/
private RelNode buildFinalProject(
RelOptRuleCall call,
RelNode unionResult,
Aggregate aggregate,
GroupingFunctionAnalysis analysis) {
RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
RelDataTypeFactory typeFactory = aggregate.getCluster().getTypeFactory();
ImmutableBitSet fullGroupSet = aggregate.getGroupSet();
List<RexNode> finalProjects = new ArrayList<>();
List<String> finalFieldNames = new ArrayList<>();
int numFields = unionResult.getRowType().getFieldCount();
for (int i = 0; i < fullGroupSet.cardinality(); i++) {
finalProjects.add(rexBuilder.makeInputRef(unionResult, i));
finalFieldNames.add(unionResult.getRowType().getFieldList().get(i).getName());
}
if (analysis.hasGroupingFunctions) {
RexNode gColumnRef = rexBuilder.makeInputRef(unionResult, numFields -
2);
RexNode groupIdColumnRef = rexBuilder.makeInputRef(unionResult,
numFields - 1);
Map<Integer, AggregateCall> groupingFuncMap = new HashMap<>();
for (int i = 0; i < analysis.groupingFunctionPositions.size(); i++) {
groupingFuncMap.put(analysis.groupingFunctionPositions.get(i),
analysis.groupingFunctionCalls.get(i));
}
int regularAggIndex = fullGroupSet.cardinality();
for (int origPos = 0; origPos < aggregate.getAggCallList().size();
origPos++) {
if (groupingFuncMap.containsKey(origPos)) {
AggregateCall groupingCall = groupingFuncMap.get(origPos);
String funcName = groupingCall.getAggregation().getName();
if ("GROUPING".equals(funcName)) {
processGrouping(groupingCall, fullGroupSet, rexBuilder,
typeFactory,
gColumnRef, finalProjects, finalFieldNames);
} else if ("GROUPING_ID".equals(funcName)) {
processGroupingId(groupingCall, fullGroupSet, rexBuilder,
typeFactory,
gColumnRef, finalProjects, finalFieldNames);
} else if ("GROUP_ID".equals(funcName)) {
finalProjects.add(groupIdColumnRef);
String fieldName = groupingCall.getName() != null
? groupingCall.getName()
: EXPRESSION_COLUMN_PLACEHOLDER + finalFieldNames.size();
finalFieldNames.add(fieldName);
}
} else {
finalProjects.add(rexBuilder.makeInputRef(unionResult,
regularAggIndex));
finalFieldNames.add(unionResult.getRowType().getFieldList().get(regularAggIndex).getName());
regularAggIndex++;
}
}
} else {
for (int i = fullGroupSet.cardinality(); i < numFields - 2; i++) {
finalProjects.add(rexBuilder.makeInputRef(unionResult, i));
finalFieldNames.add(unionResult.getRowType().getFieldList().get(i).getName());
}
}
return call.builder().push(unionResult).project(finalProjects,
finalFieldNames, false).build();
}
/**
* Builds the Rex expression that implements {@code GROUPING(column)}.
*/
private void processGrouping(AggregateCall groupingCall,
ImmutableBitSet fullGroupSet,
RexBuilder rexBuilder,
RelDataTypeFactory typeFactory,
RexNode gColumnRef,
List<RexNode> finalProjects,
List<String> finalFieldNames) {
if (groupingCall.getArgList().size() != 1) {
throw new RuntimeException("GROUPING() expects exactly 1 argument");
}
int columnIndex = groupingCall.getArgList().get(0);
int bitPosition = 0;
for (int col : fullGroupSet) {
if (col == columnIndex) break;
bitPosition++;
}
RexNode divisor = rexBuilder.makeLiteral(
1 << bitPosition, typeFactory.createSqlType(SqlTypeName.INTEGER),
true);
RexNode divided = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE,
gColumnRef, divisor);
RexNode extractBit = rexBuilder.makeCall(SqlStdOperatorTable.MOD,
divided,
rexBuilder.makeLiteral(2,
typeFactory.createSqlType(SqlTypeName.INTEGER), true));
finalProjects.add(extractBit);
String fieldName = groupingCall.getName() != null
? groupingCall.getName()
: "EXPR$" + finalFieldNames.size();
finalFieldNames.add(fieldName);
}
/**
* Builds the Rex expression that implements {@code GROUPING_ID(column,
...)}.
*/
private void processGroupingId(AggregateCall groupingCall,
ImmutableBitSet fullGroupSet,
RexBuilder rexBuilder,
RelDataTypeFactory typeFactory,
RexNode gColumnRef,
List<RexNode> finalProjects,
List<String> finalFieldNames) {
if (groupingCall.getArgList().isEmpty()) {
throw new RuntimeException("GROUPING_ID() expects at least one
argument");
}
RexNode result = null;
for (int i = 0; i < groupingCall.getArgList().size(); i++) {
int columnIndex = groupingCall.getArgList().get(i);
int bitPosition = 0;
for (int col : fullGroupSet) {
if (col == columnIndex) break;
bitPosition++;
}
RexNode divisor = rexBuilder.makeLiteral(1 << bitPosition,
typeFactory.createSqlType(SqlTypeName.INTEGER), true);
RexNode divided = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE,
gColumnRef, divisor);
RexNode extractBit = rexBuilder.makeCall(SqlStdOperatorTable.MOD,
divided,
rexBuilder.makeLiteral(2,
typeFactory.createSqlType(SqlTypeName.INTEGER), true));
int resultBitPos = groupingCall.getArgList().size() - 1 - i;
RexNode bitInPosition = (resultBitPos > 0)
? rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, extractBit,
rexBuilder.makeLiteral(1 << resultBitPos,
typeFactory.createSqlType(SqlTypeName.INTEGER), true))
: extractBit;
result = (result == null)
? bitInPosition
: rexBuilder.makeCall(SqlStdOperatorTable.PLUS, result,
bitInPosition);
}
finalProjects.add(result);
String fieldName = groupingCall.getName() != null
? groupingCall.getName()
: "EXPR$" + finalFieldNames.size();
finalFieldNames.add(fieldName);
}
}
```
> Add support for ROLLUP, CUBE, GROUPING SETS, GROUPING, GROUPING_ID, and GROUP_ID
> functions
> -----------------------------------------------------------------------------------
>
> Key: DRILL-3962
> URL: https://issues.apache.org/jira/browse/DRILL-3962
> Project: Apache Drill
> Issue Type: New Feature
> Reporter: Jinfeng Ni
> Assignee: Charles Givre
> Priority: Major
>
> These functions are important for BI analytical workloads. Calcite already
> supports them, but neither planning nor execution in Drill does.
> DRILL-3802 blocks these functions during Drill planning, but we should
> support them in both planning and execution.
--
This message was sent by Atlassian Jira
(v8.20.10#820010)