This is an automated email from the ASF dual-hosted git repository.
cgivre pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/drill.git
The following commit(s) were added to refs/heads/master by this push:
new 261f526b97 DRILL-3962: Add Support For ROLLUP, CUBE, GROUPING SETS,
GROUPING, GROUPING_ID, GROUP_ID (#3026)
261f526b97 is described below
commit 261f526b97761a29da12a0b39770dd8bff1c422d
Author: Charles S. Givre <[email protected]>
AuthorDate: Fri Oct 24 10:30:13 2025 -0400
DRILL-3962: Add Support For ROLLUP, CUBE, GROUPING SETS, GROUPING,
GROUPING_ID, GROUP_ID (#3026)
---
.../src/main/codegen/data/DateIntervalFunc.tdd | 2 +
.../drill/exec/expr/fn/impl/GroupingFunctions.java | 79 ++++
.../drill/exec/physical/config/UnionAll.java | 17 +-
.../physical/impl/union/UnionAllRecordBatch.java | 70 ++-
.../apache/drill/exec/planner/PlannerPhase.java | 5 +
.../apache/drill/exec/planner/RuleInstance.java | 4 +
.../DrillAggregateExpandGroupingSetsRule.java | 481 +++++++++++++++++++++
.../exec/planner/logical/DrillAggregateRule.java | 6 +
.../drill/exec/planner/logical/DrillOptiq.java | 5 +
.../exec/planner/logical/DrillUnionAllRule.java | 18 +-
.../drill/exec/planner/logical/DrillUnionRel.java | 14 +-
.../drill/exec/planner/physical/HashAggPrule.java | 6 +
.../exec/planner/physical/StreamAggPrule.java | 6 +
.../drill/exec/planner/physical/UnionAllPrel.java | 13 +-
.../drill/exec/planner/physical/UnionAllPrule.java | 3 +-
.../sql/parser/UnsupportedOperatorsVisitor.java | 90 +---
.../apache/drill/TestDisabledFunctionality.java | 71 ---
.../org/apache/drill/TestGroupingSetsResults.java | 372 ++++++++++++++++
18 files changed, 1089 insertions(+), 173 deletions(-)
diff --git a/exec/java-exec/src/main/codegen/data/DateIntervalFunc.tdd
b/exec/java-exec/src/main/codegen/data/DateIntervalFunc.tdd
index 12d66b284b..fe181b2feb 100644
--- a/exec/java-exec/src/main/codegen/data/DateIntervalFunc.tdd
+++ b/exec/java-exec/src/main/codegen/data/DateIntervalFunc.tdd
@@ -22,6 +22,8 @@
{truncInputTypes: ["Date", "TimeStamp", "Time", "Interval", "IntervalDay",
"IntervalYear"] },
{truncUnits : ["Second", "Minute", "Hour", "Day", "Month", "Year", "Week",
"Quarter", "Decade", "Century", "Millennium" ] },
{timestampDiffUnits : ["Nanosecond", "Microsecond", "Millisecond",
"Second", "Minute", "Hour", "Day", "Month", "Year", "Week", "Quarter"] },
+ {timestampAddUnits : ["Nanosecond", "Microsecond", "Millisecond",
"Second", "Minute", "Hour", "Day", "Month", "Year", "Week", "Quarter"] },
+ {timestampAddInputTypes : ["Date", "TimeStamp", "Time"] },
{
varCharToDate: [
diff --git
a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/GroupingFunctions.java
b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/GroupingFunctions.java
new file mode 100644
index 0000000000..a2415ab488
--- /dev/null
+++
b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/GroupingFunctions.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.drill.exec.expr.fn.impl;
+
+import org.apache.drill.exec.expr.DrillSimpleFunc;
+import org.apache.drill.exec.expr.annotations.FunctionTemplate;
+import org.apache.drill.exec.expr.annotations.Output;
+import org.apache.drill.exec.expr.annotations.Param;
+import org.apache.drill.exec.expr.holders.IntHolder;
+
+/**
+ * Functions for working with GROUPING SETS, ROLLUP, and CUBE.
+ *
+ * Note: These are internal helper functions. The actual GROUPING() and
GROUPING_ID()
+ * SQL functions need special query rewriting to work correctly with GROUPING
SETS.
+ */
+public class GroupingFunctions {
+
+ /**
+ * GROUPING_ID_INTERNAL - Returns the grouping ID bitmap.
+ * This is an internal function that will be called with the $g column value.
+ */
+ @FunctionTemplate(name = "grouping_id_internal",
+ scope = FunctionTemplate.FunctionScope.SIMPLE,
+ nulls = FunctionTemplate.NullHandling.NULL_IF_NULL)
+ public static class GroupingIdInternal implements DrillSimpleFunc {
+
+ @Param IntHolder groupingId;
+ @Output IntHolder out;
+
+ public void setup() {
+ }
+
+ public void eval() {
+ out.value = groupingId.value;
+ }
+ }
+
+ /**
+ * GROUPING_INTERNAL - Returns 1 if the specified bit in the grouping ID is
set, 0 otherwise.
+ * This is an internal function that extracts a specific bit from the
grouping ID.
+ *
+ * @param groupingId The grouping ID bitmap ($g column value)
+ * @param bitPosition The bit position to check (0-based)
+ */
+ @FunctionTemplate(name = "grouping_internal",
+ scope = FunctionTemplate.FunctionScope.SIMPLE,
+ nulls = FunctionTemplate.NullHandling.NULL_IF_NULL)
+ public static class GroupingInternal implements DrillSimpleFunc {
+
+ @Param IntHolder groupingId;
+ @Param IntHolder bitPosition;
+ @Output IntHolder out;
+
+ public void setup() {
+ }
+
+ public void eval() {
+ // Extract the bit at bitPosition from groupingId
+ // Bit is 1 if column is NOT in the grouping set (i.e., it's a grouping
NULL)
+ out.value = (groupingId.value >> bitPosition.value) & 1;
+ }
+ }
+}
diff --git
a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/config/UnionAll.java
b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/config/UnionAll.java
index 59b4bfdb09..4bc4fbd51b 100644
---
a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/config/UnionAll.java
+++
b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/config/UnionAll.java
@@ -33,9 +33,22 @@ public class UnionAll extends AbstractMultiple {
public static final String OPERATOR_TYPE = "UNION";
+ private final boolean isGroupingSetsExpansion;
+
@JsonCreator
- public UnionAll(@JsonProperty("children") List<PhysicalOperator> children) {
+ public UnionAll(@JsonProperty("children") List<PhysicalOperator> children,
+ @JsonProperty("isGroupingSetsExpansion") Boolean
isGroupingSetsExpansion) {
super(children);
+ this.isGroupingSetsExpansion = isGroupingSetsExpansion != null ?
isGroupingSetsExpansion : false;
+ }
+
+ public UnionAll(List<PhysicalOperator> children) {
+ this(children, false);
+ }
+
+ @JsonProperty("isGroupingSetsExpansion")
+ public boolean isGroupingSetsExpansion() {
+ return isGroupingSetsExpansion;
}
@Override
@@ -45,7 +58,7 @@ public class UnionAll extends AbstractMultiple {
@Override
public PhysicalOperator getNewWithChildren(List<PhysicalOperator> children) {
- return new UnionAll(children);
+ return new UnionAll(children, isGroupingSetsExpansion);
}
@Override
diff --git
a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/union/UnionAllRecordBatch.java
b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/union/UnionAllRecordBatch.java
index fad14184fa..784e78ec9e 100644
---
a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/union/UnionAllRecordBatch.java
+++
b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/union/UnionAllRecordBatch.java
@@ -17,13 +17,7 @@
*/
package org.apache.drill.exec.physical.impl.union;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.Iterator;
-import java.util.List;
-import java.util.NoSuchElementException;
-import java.util.Stack;
-
+import com.google.common.base.Preconditions;
import org.apache.calcite.util.Pair;
import org.apache.drill.common.exceptions.DrillRuntimeException;
import org.apache.drill.common.expression.ErrorCollector;
@@ -59,10 +53,16 @@ import
org.apache.drill.exec.util.record.RecordBatchStats.RecordBatchIOType;
import org.apache.drill.exec.vector.FixedWidthVector;
import org.apache.drill.exec.vector.SchemaChangeCallBack;
import org.apache.drill.exec.vector.ValueVector;
-import com.google.common.base.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.NoSuchElementException;
+import java.util.Stack;
+
public class UnionAllRecordBatch extends AbstractBinaryRecordBatch<UnionAll> {
private static final Logger logger =
LoggerFactory.getLogger(UnionAllRecordBatch.class);
@@ -278,10 +278,14 @@ public class UnionAllRecordBatch extends
AbstractBinaryRecordBatch<UnionAll> {
final Iterator<MaterializedField> leftIter = leftSchema.iterator();
final Iterator<MaterializedField> rightIter = rightSchema.iterator();
+ logger.debug("UnionAll inferring schema: isGroupingSetsExpansion={}",
popConfig.isGroupingSetsExpansion());
int index = 1;
while (leftIter.hasNext() && rightIter.hasNext()) {
MaterializedField leftField = leftIter.next();
MaterializedField rightField = rightIter.next();
+ logger.debug("Column {}: left='{}' type={}, right='{}' type={}",
+ index, leftField.getName(), leftField.getType().getMinorType(),
+ rightField.getName(), rightField.getType().getMinorType());
if (Types.isSameTypeAndMode(leftField.getType(), rightField.getType())) {
MajorType.Builder builder = MajorType.newBuilder()
@@ -301,15 +305,7 @@ public class UnionAllRecordBatch extends
AbstractBinaryRecordBatch<UnionAll> {
builder.setMinorType(leftField.getType().getMinorType());
builder = Types.calculateTypePrecisionAndScale(leftField.getType(),
rightField.getType(), builder);
} else {
- TypeProtos.MinorType outputMinorType =
TypeCastRules.getLeastRestrictiveType(
- leftField.getType().getMinorType(),
- rightField.getType().getMinorType()
- );
- if (outputMinorType == null) {
- throw new DrillRuntimeException("Type mismatch between " +
leftField.getType().getMinorType().toString() +
- " on the left side and " +
rightField.getType().getMinorType().toString() +
- " on the right side in column " + index + " of UNION ALL");
- }
+ TypeProtos.MinorType outputMinorType =
resolveUnionColumnType(leftField, rightField, index);
builder.setMinorType(outputMinorType);
}
@@ -328,6 +324,46 @@ public class UnionAllRecordBatch extends
AbstractBinaryRecordBatch<UnionAll> {
"Mismatch of column count should have been detected when validating
sqlNode at planning";
}
+ /**
+ * Determines the output type for a UNION ALL column when combining two
types.
+ * <p>
+ * Special handling is applied for GROUPING SETS expansion:
+ * - Drill represents NULL columns as INT during grouping sets expansion.
+ * - If one side is INT (likely a NULL placeholder) and the other is not,
prefer the non-INT type.
+ * <p>
+ * For all other cases, the least restrictive type according to Drill's type
cast rules is returned.
+ *
+ * @param leftField The left input's column field, whose minor type is resolved
+ * @param rightField The right input's column field, whose minor type is resolved
+ * @param index The column index (used in log and error messages)
+ * @return The resolved output type
+ * @throws DrillRuntimeException if types are incompatible
+ */
+ private TypeProtos.MinorType resolveUnionColumnType(MaterializedField
leftField,
+ MaterializedField rightField,
+ int index) {
+ TypeProtos.MinorType leftType = leftField.getType().getMinorType();
+ TypeProtos.MinorType rightType = rightField.getType().getMinorType();
+
+ boolean isGroupingSets = popConfig.isGroupingSetsExpansion();
+ boolean leftIsPlaceholder = leftType == TypeProtos.MinorType.INT &&
rightType != TypeProtos.MinorType.INT;
+ boolean rightIsPlaceholder = rightType == TypeProtos.MinorType.INT &&
leftType != TypeProtos.MinorType.INT;
+
+ if (isGroupingSets && (leftIsPlaceholder || rightIsPlaceholder)) {
+ TypeProtos.MinorType outputType = leftIsPlaceholder ? rightType :
leftType;
+ logger.debug("GROUPING SETS: Preferring {} over INT for column {}",
outputType, index);
+ return outputType;
+ }
+
+ TypeProtos.MinorType outputType =
TypeCastRules.getLeastRestrictiveType(leftType, rightType);
+ if (outputType == null) {
+ throw new DrillRuntimeException("Type mismatch between " + leftType +
+ " and " + rightType + " in column " + index + " of UNION ALL");
+ }
+ logger.debug("Using standard type rules: {} + {} -> {}", leftType,
rightType, outputType);
+ return outputType;
+ }
+
private void inferOutputFieldsOneSide(final BatchSchema schema) {
for (MaterializedField field : schema) {
container.addOrGet(field, callBack);
diff --git
a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/PlannerPhase.java
b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/PlannerPhase.java
index 4a32b3a9bd..a8c5224a23 100644
---
a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/PlannerPhase.java
+++
b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/PlannerPhase.java
@@ -356,6 +356,11 @@ public enum PlannerPhase {
Convert from Calcite Logical to Drill Logical Rules.
*/
RuleInstance.EXPAND_CONVERSION_RULE,
+
+ // Expand GROUPING SETS, ROLLUP, and CUBE BEFORE converting aggregates
to Drill logical operators
+ // This prevents multi-grouping-set aggregates from being converted to
DrillAggregateRel
+ RuleInstance.AGGREGATE_EXPAND_GROUPING_SETS_RULE,
+
DrillScanRule.INSTANCE,
DrillFilterRule.INSTANCE,
DrillProjectRule.INSTANCE,
diff --git
a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/RuleInstance.java
b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/RuleInstance.java
index baa39dba23..a370c64e76 100644
---
a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/RuleInstance.java
+++
b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/RuleInstance.java
@@ -39,6 +39,7 @@ import org.apache.calcite.rel.rules.SemiJoinRule;
import org.apache.calcite.rel.rules.SortRemoveRule;
import org.apache.calcite.rel.rules.SubQueryRemoveRule;
import org.apache.calcite.rel.rules.UnionToDistinctRule;
+import
org.apache.drill.exec.planner.logical.DrillAggregateExpandGroupingSetsRule;
import org.apache.drill.exec.planner.logical.DrillConditions;
import org.apache.drill.exec.planner.logical.DrillRelFactories;
import com.google.common.base.Preconditions;
@@ -107,6 +108,9 @@ public interface RuleInstance {
.withRelBuilderFactory(DrillRelFactories.LOGICAL_BUILDER)
.toRule();
+ RelOptRule AGGREGATE_EXPAND_GROUPING_SETS_RULE =
+ DrillAggregateExpandGroupingSetsRule.INSTANCE;
+
/**
* Instance of the rule that works on logical joins only, and pushes to the
* right.
diff --git
a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillAggregateExpandGroupingSetsRule.java
b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillAggregateExpandGroupingSetsRule.java
new file mode 100644
index 0000000000..9756a40882
--- /dev/null
+++
b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillAggregateExpandGroupingSetsRule.java
@@ -0,0 +1,481 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.drill.exec.planner.logical;
+
+import com.google.common.collect.ImmutableList;
+import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.rel.InvalidRelException;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.logical.LogicalAggregate;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.calcite.util.ImmutableBitSet;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Planner rule that expands GROUPING SETS, ROLLUP, and CUBE into a UNION ALL
+ * of multiple aggregates, each with a single grouping set.
+ * <p>
+ * This rule converts:
+ * SELECT a, b, SUM(c) FROM t GROUP BY GROUPING SETS ((a, b), (a), ())
+ * <p>
+ * Into:
+ * SELECT a, b, SUM(c), 0 AS $g FROM t GROUP BY a, b
+ * UNION ALL
+ * SELECT a, null, SUM(c), 2 AS $g FROM t GROUP BY a
+ * UNION ALL
+ * SELECT null, null, SUM(c), 3 AS $g FROM t GROUP BY ()
+ * <p>
+ * The $g column is the grouping ID that can be used by GROUPING() and
GROUPING_ID() functions.
+ * Currently, the $g column is generated internally but stripped from the
final output.
+ */
+public class DrillAggregateExpandGroupingSetsRule extends RelOptRule {
+
+ public static final DrillAggregateExpandGroupingSetsRule INSTANCE =
+ new DrillAggregateExpandGroupingSetsRule();
+ public static final String GROUPING_ID_COLUMN_NAME = "$g";
+ public static final String GROUP_ID_COLUMN_NAME = "$group_id";
+ public static final String EXPRESSION_COLUMN_PLACEHOLDER = "EXPR$";
+
+ private DrillAggregateExpandGroupingSetsRule() {
+ super(operand(Aggregate.class, any()), DrillRelFactories.LOGICAL_BUILDER,
+ "DrillAggregateExpandGroupingSetsRule");
+ }
+
+ @Override
+ public boolean matches(RelOptRuleCall call) {
+ final Aggregate aggregate = call.rel(0);
+ return aggregate.getGroupSets().size() > 1
+ && (aggregate instanceof DrillAggregateRel || aggregate instanceof
LogicalAggregate);
+ }
+
+ @Override
+ public void onMatch(RelOptRuleCall call) {
+ final Aggregate aggregate = call.rel(0);
+ final RelOptCluster cluster = aggregate.getCluster();
+
+ GroupingFunctionAnalysis analysis =
analyzeGroupingFunctions(aggregate.getAggCallList());
+ GroupingSetOrderingResult ordering =
sortAndAssignGroupIds(aggregate.getGroupSets());
+
+ List<RelNode> perGroupAggregates = new ArrayList<>();
+ for (int i = 0; i < ordering.sortedGroupSets.size(); i++) {
+ perGroupAggregates.add(
+ createAggregateForGroupingSet(call, aggregate,
ordering.sortedGroupSets.get(i),
+ ordering.groupIds.get(i), analysis.regularAggCalls));
+ }
+
+ RelNode unionResult = buildUnion(cluster, perGroupAggregates);
+ RelNode result = buildFinalProject(call, unionResult, aggregate, analysis);
+
+ call.transformTo(result);
+ }
+
+ /**
+ * Encapsulates analysis results of aggregate calls to determine
+ * which are regular aggregates and which are grouping-related
+ * functions (GROUPING, GROUPING_ID, GROUP_ID).
+ */
+ private static class GroupingFunctionAnalysis {
+ final boolean hasGroupingFunctions;
+ final List<AggregateCall> regularAggCalls;
+ final List<AggregateCall> groupingFunctionCalls;
+ final List<Integer> groupingFunctionPositions;
+
+ GroupingFunctionAnalysis(List<AggregateCall> regularAggCalls,
+ List<AggregateCall> groupingFunctionCalls,
+ List<Integer> groupingFunctionPositions) {
+ this.hasGroupingFunctions = !groupingFunctionPositions.isEmpty();
+ this.regularAggCalls = regularAggCalls;
+ this.groupingFunctionCalls = groupingFunctionCalls;
+ this.groupingFunctionPositions = groupingFunctionPositions;
+ }
+ }
+
+ /**
+ * Holds the sorted grouping sets (largest first) and their assigned group
IDs.
+ */
+ private static class GroupingSetOrderingResult {
+ final List<ImmutableBitSet> sortedGroupSets;
+ final List<Integer> groupIds;
+ GroupingSetOrderingResult(List<ImmutableBitSet> sortedGroupSets,
List<Integer> groupIds) {
+ this.sortedGroupSets = sortedGroupSets;
+ this.groupIds = groupIds;
+ }
+ }
+
+ /**
+ * Analyzes aggregate calls to identify which ones are GROUPING-related
functions.
+ *
+ * @param aggCalls list of aggregate calls in the original aggregate
+ * @return structure classifying grouping and non-grouping calls
+ */
+ private GroupingFunctionAnalysis
analyzeGroupingFunctions(List<AggregateCall> aggCalls) {
+ List<AggregateCall> regularAggCalls = new ArrayList<>();
+ List<AggregateCall> groupingFunctionCalls = new ArrayList<>();
+ List<Integer> groupingFunctionPositions = new ArrayList<>();
+
+ for (int i = 0; i < aggCalls.size(); i++) {
+ AggregateCall aggCall = aggCalls.get(i);
+ SqlKind kind = aggCall.getAggregation().getKind();
+ switch (kind) {
+ case GROUPING:
+ case GROUPING_ID:
+ case GROUP_ID:
+ groupingFunctionPositions.add(i);
+ groupingFunctionCalls.add(aggCall);
+ break;
+ default:
+ regularAggCalls.add(aggCall);
+ }
+ }
+
+ return new GroupingFunctionAnalysis(regularAggCalls,
+ groupingFunctionCalls, groupingFunctionPositions);
+ }
+
+ /**
+ * Sorts the given grouping sets in descending order of their cardinality
+ * and assigns each one a group ID: 0 for its first occurrence, 1 for the next duplicate, and so on.
+ *
+ * @param groupSets a list of grouping sets represented as ImmutableBitSet
instances
+ * @return a GroupingSetOrderingResult containing the sorted grouping sets
and their assigned group IDs
+ */
+ private GroupingSetOrderingResult
sortAndAssignGroupIds(List<ImmutableBitSet> groupSets) {
+ List<ImmutableBitSet> sortedGroupSets = new ArrayList<>(groupSets);
+ sortedGroupSets.sort((a, b) -> Integer.compare(b.cardinality(),
a.cardinality()));
+
+ Map<ImmutableBitSet, Integer> groupSetOccurrences = new HashMap<>();
+ List<Integer> groupIds = new ArrayList<>();
+
+ for (ImmutableBitSet groupSet : sortedGroupSets) {
+ int groupId = groupSetOccurrences.getOrDefault(groupSet, 0);
+ groupIds.add(groupId);
+ groupSetOccurrences.put(groupSet, groupId + 1);
+ }
+
+ return new GroupingSetOrderingResult(sortedGroupSets, groupIds);
+ }
+
+ /**
+ * Creates a new aggregate relational node for a specific grouping set. This
method constructs
+ * the necessary aggregation logic and ensures proper handling of grouping
columns, aggregate
+ * calls, and additional metadata such as grouping and group IDs.
+ *
+ * @param call the RelOptRuleCall instance being processed
+ * @param originalAgg the original aggregate relational expression
+ * @param groupSet the grouping set to be handled in the new aggregate
+ * @param groupId the unique identifier associated with this specific
grouping set
+ * @param regularAggCalls the list of regular aggregate calls to be included
in the aggregate
+ * @return a RelNode representing the newly created aggregate for the
specified grouping set
+ */
+ private RelNode createAggregateForGroupingSet(
+ RelOptRuleCall call,
+ Aggregate originalAgg,
+ ImmutableBitSet groupSet,
+ int groupId,
+ List<AggregateCall> regularAggCalls) {
+
+ ImmutableBitSet fullGroupSet = originalAgg.getGroupSet();
+ RelOptCluster cluster = originalAgg.getCluster();
+ RexBuilder rexBuilder = cluster.getRexBuilder();
+ RelDataTypeFactory typeFactory = cluster.getTypeFactory();
+ RelNode input = originalAgg.getInput();
+
+ Aggregate newAggregate;
+ if (originalAgg instanceof DrillAggregateRel) {
+ newAggregate = new DrillAggregateRel(cluster, originalAgg.getTraitSet(),
input,
+ groupSet, ImmutableList.of(groupSet), regularAggCalls);
+ } else {
+ newAggregate = originalAgg.copy(originalAgg.getTraitSet(), input,
groupSet,
+ ImmutableList.of(groupSet), regularAggCalls);
+ }
+
+ List<RexNode> projects = new ArrayList<>();
+ List<String> fieldNames = new ArrayList<>();
+ int aggOutputIdx = 0;
+ int outputColIdx = 0;
+
+ // Populate grouping columns (nulls for omitted columns)
+ for (int col : fullGroupSet) {
+ if (groupSet.get(col)) {
+ projects.add(rexBuilder.makeInputRef(newAggregate, aggOutputIdx++));
+ } else {
+ RelDataType nullType =
originalAgg.getRowType().getFieldList().get(outputColIdx).getType();
+ projects.add(rexBuilder.makeNullLiteral(nullType));
+ }
+
fieldNames.add(originalAgg.getRowType().getFieldList().get(outputColIdx++).getName());
+ }
+
+ // Add regular aggregates
+ for (AggregateCall regCall : regularAggCalls) {
+ projects.add(rexBuilder.makeInputRef(newAggregate, aggOutputIdx++));
+ fieldNames.add(regCall.getName() != null ? regCall.getName() : "agg$" +
aggOutputIdx);
+ }
+
+ // Add grouping ID ($g)
+ int groupingId = computeGroupingId(fullGroupSet, groupSet);
+ projects.add(rexBuilder.makeLiteral(groupingId,
+ typeFactory.createSqlType(SqlTypeName.INTEGER), true));
+ fieldNames.add(GROUPING_ID_COLUMN_NAME);
+
+ // Add group ID ($group_id)
+ projects.add(rexBuilder.makeLiteral(groupId,
+ typeFactory.createSqlType(SqlTypeName.INTEGER), true));
+ fieldNames.add(GROUP_ID_COLUMN_NAME);
+
+ return call.builder().push(newAggregate).project(projects, fieldNames,
false).build();
+ }
+
+ private int computeGroupingId(ImmutableBitSet fullGroupSet, ImmutableBitSet
groupSet) {
+ int id = 0;
+ int bit = 0;
+ for (int col : fullGroupSet) {
+ if (!groupSet.get(col)) {
+ id |= (1 << bit);
+ }
+ bit++;
+ }
+ return id;
+ }
+
+ /**
+ * Builds a union of the given aggregate relational nodes. If there is only
one
+ * aggregate node, it returns that node directly. Otherwise, it constructs a
+ * union relational expression containing all the provided aggregate nodes.
+ *
+ * @param cluster the optimization cluster in which the relational node
resides
+ * @param aggregates a list of aggregate relational nodes to be combined
into a union
+ * @return the resultant union relational node if multiple nodes are
provided;
+ * otherwise, the single aggregate node from the input list
+ * @throws RuntimeException if union creation fails due to invalid
relational state
+ */
+ private RelNode buildUnion(RelOptCluster cluster, List<RelNode> aggregates) {
+ if (aggregates.size() == 1) {
+ return aggregates.get(0);
+ }
+ try {
+ List<RelNode> convertedInputs = new ArrayList<>();
+ for (RelNode agg : aggregates) {
+ convertedInputs.add(convert(agg,
agg.getTraitSet().plus(DrillRel.DRILL_LOGICAL).simplify()));
+ }
+ return new DrillUnionRel(cluster,
+ cluster.traitSet().plus(DrillRel.DRILL_LOGICAL),
+ convertedInputs,
+ true,
+ true,
+ true);
+ } catch (InvalidRelException e) {
+ throw new RuntimeException("Failed to create DrillUnionRel", e);
+ }
+ }
+
+ /**
+ * Builds the final projection for the result of the aggregation,
incorporating the necessary
+ * output columns such as the grouping functions and the aggregation results.
+ *
+ * @param call the RelOptRuleCall instance being processed
+ * @param unionResult the relational expression resulting from the union of
partial aggregates
+ * @param aggregate the original Aggregate relational expression
+ * @param analysis the analysis results classifying regular and
grouping-related aggregate calls
+ * @return the relational expression with the final projection applied
+ */
+ private RelNode buildFinalProject(
+ RelOptRuleCall call,
+ RelNode unionResult,
+ Aggregate aggregate,
+ GroupingFunctionAnalysis analysis) {
+
+ RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
+ RelDataTypeFactory typeFactory = aggregate.getCluster().getTypeFactory();
+ ImmutableBitSet fullGroupSet = aggregate.getGroupSet();
+ List<RexNode> finalProjects = new ArrayList<>();
+ List<String> finalFieldNames = new ArrayList<>();
+ int numFields = unionResult.getRowType().getFieldCount();
+
+ for (int i = 0; i < fullGroupSet.cardinality(); i++) {
+ finalProjects.add(rexBuilder.makeInputRef(unionResult, i));
+
finalFieldNames.add(unionResult.getRowType().getFieldList().get(i).getName());
+ }
+
+ if (analysis.hasGroupingFunctions) {
+ RexNode gColumnRef = rexBuilder.makeInputRef(unionResult, numFields - 2);
+ RexNode groupIdColumnRef = rexBuilder.makeInputRef(unionResult,
numFields - 1);
+ Map<Integer, AggregateCall> groupingFuncMap = new HashMap<>();
+ for (int i = 0; i < analysis.groupingFunctionPositions.size(); i++) {
+ groupingFuncMap.put(analysis.groupingFunctionPositions.get(i),
+ analysis.groupingFunctionCalls.get(i));
+ }
+
+ int regularAggIndex = fullGroupSet.cardinality();
+ for (int origPos = 0; origPos < aggregate.getAggCallList().size();
origPos++) {
+ if (groupingFuncMap.containsKey(origPos)) {
+ AggregateCall groupingCall = groupingFuncMap.get(origPos);
+ String funcName = groupingCall.getAggregation().getName();
+ if ("GROUPING".equals(funcName)) {
+ processGrouping(groupingCall, fullGroupSet, rexBuilder,
typeFactory,
+ gColumnRef, finalProjects, finalFieldNames);
+ } else if ("GROUPING_ID".equals(funcName)) {
+ processGroupingId(groupingCall, fullGroupSet, rexBuilder,
typeFactory,
+ gColumnRef, finalProjects, finalFieldNames);
+ } else if ("GROUP_ID".equals(funcName)) {
+ finalProjects.add(groupIdColumnRef);
+ String fieldName = groupingCall.getName() != null
+ ? groupingCall.getName()
+ : EXPRESSION_COLUMN_PLACEHOLDER + finalFieldNames.size();
+ finalFieldNames.add(fieldName);
+ }
+ } else {
+ finalProjects.add(rexBuilder.makeInputRef(unionResult,
regularAggIndex));
+
finalFieldNames.add(unionResult.getRowType().getFieldList().get(regularAggIndex).getName());
+ regularAggIndex++;
+ }
+ }
+ } else {
+ for (int i = fullGroupSet.cardinality(); i < numFields - 2; i++) {
+ finalProjects.add(rexBuilder.makeInputRef(unionResult, i));
+
finalFieldNames.add(unionResult.getRowType().getFieldList().get(i).getName());
+ }
+ }
+
+ return call.builder().push(unionResult).project(finalProjects,
finalFieldNames, false).build();
+ }
+
+ /**
+ * Processes the GROUPING aggregate function by extracting the bit
representing
+ * whether the argument column is aggregated or not and appends the computed RexNode
+ * projection and the corresponding field name to the provided lists.
+ *
+ * @param groupingCall the GROUPING aggregate function call to process
+ * @param fullGroupSet the complete set of grouping keys for the aggregation
+ * @param rexBuilder the RexBuilder instance used to construct RexNode
expressions
+ * @param typeFactory the data type factory used for creating type-specific
literals
+ * @param gColumnRef the RexNode reference to the grouping ID ($g) column
+ * @param finalProjects the list to store the constructed RexNode projections
+ * @param finalFieldNames the list to store the corresponding field names
+ */
+ private void processGrouping(AggregateCall groupingCall,
+ ImmutableBitSet fullGroupSet,
+ RexBuilder rexBuilder,
+ RelDataTypeFactory typeFactory,
+ RexNode gColumnRef,
+ List<RexNode> finalProjects,
+ List<String> finalFieldNames) {
+
+ if (groupingCall.getArgList().size() != 1) {
+ throw new RuntimeException("GROUPING() expects exactly 1 argument");
+ }
+
+ int columnIndex = groupingCall.getArgList().get(0);
+ int bitPosition = 0;
+ for (int col : fullGroupSet) {
+ if (col == columnIndex) {
+ break;
+ }
+ bitPosition++;
+ }
+
+ RexNode divisor = rexBuilder.makeLiteral(
+ 1 << bitPosition, typeFactory.createSqlType(SqlTypeName.INTEGER),
true);
+
+ RexNode divided = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE,
gColumnRef, divisor);
+ RexNode extractBit = rexBuilder.makeCall(SqlStdOperatorTable.MOD, divided,
+ rexBuilder.makeLiteral(2,
typeFactory.createSqlType(SqlTypeName.INTEGER), true));
+
+ finalProjects.add(extractBit);
+ String fieldName = groupingCall.getName() != null
+ ? groupingCall.getName()
+ : "EXPR$" + finalFieldNames.size();
+ finalFieldNames.add(fieldName);
+ }
+
+ /**
+ * Processes the GROUPING_ID aggregate function by computing a bitmask
+ * based on the provided grouping columns and full group set. Constructs
+ * the corresponding RexNode representation for the GROUPING_ID function
+ * and appends it to the final projection and field names list.
+ *
+ * @param groupingCall the GROUPING_ID aggregate function call to process
+ * @param fullGroupSet the complete set of grouping keys for the aggregation
+ * @param rexBuilder the RexBuilder instance used to construct RexNode
expressions
+ * @param typeFactory the data type factory for creating type-specific
literals
+ * @param gColumnRef the RexNode reference to the grouping ID ($g) column
+ * @param finalProjects the list to which the computed RexNode is added
+ * @param finalFieldNames the list to which the corresponding field name is
added
+ */
+ private void processGroupingId(AggregateCall groupingCall,
+ ImmutableBitSet fullGroupSet,
+ RexBuilder rexBuilder,
+ RelDataTypeFactory typeFactory,
+ RexNode gColumnRef,
+ List<RexNode> finalProjects,
+ List<String> finalFieldNames) {
+
+ if (groupingCall.getArgList().isEmpty()) {
+ throw new RuntimeException("GROUPING_ID() expects at least one
argument");
+ }
+
+ RexNode result = null;
+ for (int i = 0; i < groupingCall.getArgList().size(); i++) {
+ int columnIndex = groupingCall.getArgList().get(i);
+ int bitPosition = 0;
+ for (int col : fullGroupSet) {
+ if (col == columnIndex) {
+ break;
+ }
+ bitPosition++;
+ }
+
+ RexNode divisor = rexBuilder.makeLiteral(1 << bitPosition,
+ typeFactory.createSqlType(SqlTypeName.INTEGER), true);
+
+ RexNode divided = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE,
gColumnRef, divisor);
+ RexNode extractBit = rexBuilder.makeCall(SqlStdOperatorTable.MOD,
divided,
+ rexBuilder.makeLiteral(2,
typeFactory.createSqlType(SqlTypeName.INTEGER), true));
+
+ int resultBitPos = groupingCall.getArgList().size() - 1 - i;
+ RexNode bitInPosition = (resultBitPos > 0)
+ ? rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, extractBit,
+ rexBuilder.makeLiteral(1 << resultBitPos,
+ typeFactory.createSqlType(SqlTypeName.INTEGER), true))
+ : extractBit;
+
+ result = (result == null)
+ ? bitInPosition
+ : rexBuilder.makeCall(SqlStdOperatorTable.PLUS, result,
bitInPosition);
+ }
+
+ finalProjects.add(result);
+ String fieldName = groupingCall.getName() != null
+ ? groupingCall.getName()
+ : "EXPR$" + finalFieldNames.size();
+ finalFieldNames.add(fieldName);
+ }
+}
diff --git
a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillAggregateRule.java
b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillAggregateRule.java
index 8ce07dd229..86d67a866c 100644
---
a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillAggregateRule.java
+++
b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillAggregateRule.java
@@ -49,6 +49,12 @@ public class DrillAggregateRule extends RelOptRule {
return;
}
+ if (aggregate.getGroupSets().size() > 1) {
+ // Don't convert aggregates with multiple grouping sets (GROUPING
SETS/ROLLUP/CUBE) to DrillAggregateRel
+ // These should be expanded into UNION ALL by
DrillAggregateExpandGroupingSetsRule first
+ return;
+ }
+
final RelTraitSet traits =
aggregate.getTraitSet().plus(DrillRel.DRILL_LOGICAL);
final RelNode convertedInput = convert(input,
input.getTraitSet().plus(DrillRel.DRILL_LOGICAL).simplify());
call.transformTo(new DrillAggregateRel(aggregate.getCluster(), traits,
convertedInput,
diff --git
a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillOptiq.java
b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillOptiq.java
index 9cf1b26d45..349ba2a02f 100644
---
a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillOptiq.java
+++
b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillOptiq.java
@@ -917,6 +917,11 @@ public class DrillOptiq {
return (ValueExpressions.getIntervalDay(((BigDecimal)
(literal.getValue())).longValue()));
case NULL:
return NullExpression.INSTANCE;
+ case UNKNOWN:
+ // UNKNOWN type is used for NULL literals where the type should be
inferred later
+ // This is used by GROUPING SETS expansion where NULL placeholders
need type inference
+ // from the other branch of UNION ALL
+ return NullExpression.INSTANCE;
case ANY:
if (isLiteralNull(literal)) {
return NullExpression.INSTANCE;
diff --git
a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillUnionAllRule.java
b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillUnionAllRule.java
index a35d320bc2..18719b5447 100644
---
a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillUnionAllRule.java
+++
b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillUnionAllRule.java
@@ -57,9 +57,25 @@ public class DrillUnionAllRule extends RelOptRule {
final RelNode convertedInput = convert(input,
input.getTraitSet().plus(DrillRel.DRILL_LOGICAL).simplify());
convertedInputs.add(convertedInput);
}
+
+ // Detect if this union is from GROUPING SETS expansion by checking if ANY
input has a $g column
+ // The $g column is the grouping ID that we add during expansion
+ // Check all inputs because the union tree may be built incrementally
(binary tree structure)
+ boolean isGroupingSetsExpansion = false;
+ for (RelNode input : convertedInputs) {
+ org.apache.calcite.rel.type.RelDataType inputType = input.getRowType();
+ if (inputType.getFieldCount() > 0) {
+ String lastFieldName =
inputType.getFieldList().get(inputType.getFieldCount() - 1).getName();
+ if ("$g".equals(lastFieldName)) {
+ isGroupingSetsExpansion = true;
+ break;
+ }
+ }
+ }
+
try {
call.transformTo(new DrillUnionRel(union.getCluster(), traits,
convertedInputs, union.all,
- true /* check compatibility */));
+ true /* check compatibility */, isGroupingSetsExpansion));
} catch (InvalidRelException e) {
tracer.warn(e.toString());
}
diff --git
a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillUnionRel.java
b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillUnionRel.java
index 263266e1ff..7d01ec110e 100644
---
a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillUnionRel.java
+++
b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillUnionRel.java
@@ -37,21 +37,33 @@ import org.apache.calcite.plan.RelTraitSet;
* Union implemented in Drill.
*/
public class DrillUnionRel extends Union implements DrillRel, DrillSetOpRel {
+ private final boolean isGroupingSetsExpansion;
+
/** Creates a DrillUnionRel. */
public DrillUnionRel(RelOptCluster cluster, RelTraitSet traits,
List<RelNode> inputs, boolean all, boolean checkCompatibility) throws
InvalidRelException {
+ this(cluster, traits, inputs, all, checkCompatibility, false);
+ }
+
+ public DrillUnionRel(RelOptCluster cluster, RelTraitSet traits,
+ List<RelNode> inputs, boolean all, boolean checkCompatibility, boolean
isGroupingSetsExpansion) throws InvalidRelException {
super(cluster, traits, inputs, all);
+ this.isGroupingSetsExpansion = isGroupingSetsExpansion;
if (checkCompatibility && !this.isCompatible(getRowType(), getInputs())) {
throw new InvalidRelException("Input row types of the Union are not
compatible.");
}
}
+ public boolean isGroupingSetsExpansion() {
+ return isGroupingSetsExpansion;
+ }
+
@Override
public DrillUnionRel copy(RelTraitSet traitSet, List<RelNode> inputs,
boolean all) {
try {
return new DrillUnionRel(getCluster(), traitSet, inputs, all,
- false /* don't check compatibility during copy */);
+ false /* don't check compatibility during copy */,
isGroupingSetsExpansion);
} catch (InvalidRelException e) {
throw new AssertionError(e);
}
diff --git
a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/HashAggPrule.java
b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/HashAggPrule.java
index a61db99510..1b3805e5bd 100644
---
a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/HashAggPrule.java
+++
b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/HashAggPrule.java
@@ -64,6 +64,12 @@ public class HashAggPrule extends AggPruleBase {
return;
}
+ if (aggregate.getGroupSets().size() > 1) {
+ // Don't use HashAggregate for aggregates with multiple grouping sets
(GROUPING SETS/ROLLUP/CUBE)
+ // These should be expanded into UNION ALL by
DrillAggregateExpandGroupingSetsRule first
+ return;
+ }
+
RelTraitSet traits;
try {
diff --git
a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/StreamAggPrule.java
b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/StreamAggPrule.java
index cbed109d19..49a4e93539 100644
---
a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/StreamAggPrule.java
+++
b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/StreamAggPrule.java
@@ -64,6 +64,12 @@ public class StreamAggPrule extends AggPruleBase {
return;
}
+ if (aggregate.getGroupSets().size() > 1) {
+ // Don't use StreamingAggregate for aggregates with multiple grouping
sets (GROUPING SETS/ROLLUP/CUBE)
+ // These should be expanded into UNION ALL by
DrillAggregateExpandGroupingSetsRule first
+ return;
+ }
+
try {
if (aggregate.getGroupSet().isEmpty()) {
DrillDistributionTrait singleDist = DrillDistributionTrait.SINGLETON;
diff --git
a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/UnionAllPrel.java
b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/UnionAllPrel.java
index 460346fa11..a60c7fc276 100644
---
a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/UnionAllPrel.java
+++
b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/UnionAllPrel.java
@@ -39,17 +39,24 @@ import com.google.common.collect.Lists;
public class UnionAllPrel extends UnionPrel {
+ private final boolean isGroupingSetsExpansion;
+
public UnionAllPrel(RelOptCluster cluster, RelTraitSet traits, List<RelNode>
inputs)
throws InvalidRelException {
- super(cluster, traits, inputs, true /* all */);
+ this(cluster, traits, inputs, false);
+ }
+ public UnionAllPrel(RelOptCluster cluster, RelTraitSet traits, List<RelNode>
inputs, boolean isGroupingSetsExpansion)
+ throws InvalidRelException {
+ super(cluster, traits, inputs, true /* all */);
+ this.isGroupingSetsExpansion = isGroupingSetsExpansion;
}
@Override
public Union copy(RelTraitSet traitSet, List<RelNode> inputs, boolean all) {
try {
- return new UnionAllPrel(this.getCluster(), traitSet, inputs);
+ return new UnionAllPrel(this.getCluster(), traitSet, inputs,
isGroupingSetsExpansion);
}catch (InvalidRelException e) {
throw new AssertionError(e);
}
@@ -78,7 +85,7 @@ public class UnionAllPrel extends UnionPrel {
inputPops.add(
((Prel)this.getInputs().get(i)).getPhysicalOperator(creator));
}
- UnionAll unionall = new UnionAll(inputPops);
+ UnionAll unionall = new UnionAll(inputPops, isGroupingSetsExpansion);
return creator.addMetadata(this, unionall);
}
diff --git
a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/UnionAllPrule.java
b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/UnionAllPrule.java
index 8e094604cf..40ebe7f1cd 100644
---
a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/UnionAllPrule.java
+++
b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/UnionAllPrule.java
@@ -101,7 +101,8 @@ public class UnionAllPrule extends Prule {
Preconditions.checkArgument(convertedInputList.size() >= 2, "Union list
must be at least two items.");
RelNode left = convertedInputList.get(0);
for (int i = 1; i < convertedInputList.size(); i++) {
- left = new UnionAllPrel(union.getCluster(), traits,
ImmutableList.of(left, convertedInputList.get(i)));
+ left = new UnionAllPrel(union.getCluster(), traits,
ImmutableList.of(left, convertedInputList.get(i)),
+ union.isGroupingSetsExpansion());
}
call.transformTo(left);
diff --git
a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/parser/UnsupportedOperatorsVisitor.java
b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/parser/UnsupportedOperatorsVisitor.java
index f433308ac2..680e3ca391 100644
---
a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/parser/UnsupportedOperatorsVisitor.java
+++
b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/parser/UnsupportedOperatorsVisitor.java
@@ -17,10 +17,21 @@
*/
package org.apache.drill.exec.planner.sql.parser;
+import com.google.common.collect.Lists;
+import org.apache.calcite.sql.SqlCall;
+import org.apache.calcite.sql.SqlDataTypeSpec;
+import org.apache.calcite.sql.SqlIdentifier;
+import org.apache.calcite.sql.SqlJoin;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNumericLiteral;
-import org.apache.calcite.sql.SqlOperator;
-import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.SqlSelect;
+import org.apache.calcite.sql.SqlSelectKeyword;
+import org.apache.calcite.sql.SqlWindow;
+import org.apache.calcite.sql.fun.SqlCountAggFunction;
+import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.util.SqlBasicVisitor;
+import org.apache.calcite.sql.util.SqlShuttle;
import org.apache.calcite.util.Litmus;
import org.apache.drill.exec.ExecConstants;
import org.apache.drill.exec.exception.UnsupportedOperatorCollector;
@@ -28,23 +39,8 @@ import org.apache.drill.exec.ops.QueryContext;
import org.apache.drill.exec.planner.physical.PlannerSettings;
import org.apache.drill.exec.work.foreman.SqlUnsupportedException;
-import org.apache.calcite.sql.SqlSelectKeyword;
-import org.apache.calcite.sql.SqlIdentifier;
-import org.apache.calcite.sql.SqlSelect;
-import org.apache.calcite.sql.SqlWindow;
-import org.apache.calcite.sql.fun.SqlCountAggFunction;
-import org.apache.calcite.sql.SqlCall;
-import org.apache.calcite.sql.SqlKind;
-import org.apache.calcite.sql.SqlJoin;
-import org.apache.calcite.sql.SqlNode;
-import org.apache.calcite.sql.type.SqlTypeName;
-import org.apache.calcite.sql.util.SqlShuttle;
-import org.apache.calcite.sql.SqlDataTypeSpec;
-
import java.util.List;
-import com.google.common.collect.Lists;
-
public class UnsupportedOperatorsVisitor extends SqlShuttle {
private QueryContext context;
private static List<String> disabledType = Lists.newArrayList();
@@ -97,10 +93,6 @@ public class UnsupportedOperatorsVisitor extends SqlShuttle {
if (sqlCall instanceof SqlSelect) {
SqlSelect sqlSelect = (SqlSelect) sqlCall;
- checkGrouping((sqlSelect));
-
- checkRollupCubeGrpSets(sqlSelect);
-
for (SqlNode nodeInSelectList : sqlSelect.getSelectList()) {
// If the window function is used with an alias,
// enter the first operand of AS operator
@@ -358,27 +350,6 @@ public class UnsupportedOperatorsVisitor extends
SqlShuttle {
return sqlCall.getOperator().acceptCall(this, sqlCall);
}
- private void checkRollupCubeGrpSets(SqlSelect sqlSelect) {
- final ExprFinder rollupCubeGrpSetsFinder = new
ExprFinder(RollupCubeGrpSets);
- sqlSelect.accept(rollupCubeGrpSetsFinder);
- if (rollupCubeGrpSetsFinder.find()) {
-
unsupportedOperatorCollector.setException(SqlUnsupportedException.ExceptionType.FUNCTION,
- "Rollup, Cube, Grouping Sets are not supported in GROUP BY
clause.\n" +
- "See Apache Drill JIRA: DRILL-3962");
- throw new UnsupportedOperationException();
- }
- }
-
- private void checkGrouping(SqlSelect sqlSelect) {
- final ExprFinder groupingFinder = new ExprFinder(GroupingID);
- sqlSelect.accept(groupingFinder);
- if (groupingFinder.find()) {
-
unsupportedOperatorCollector.setException(SqlUnsupportedException.ExceptionType.FUNCTION,
- "Grouping, Grouping_ID, Group_ID are not supported.\n" +
- "See Apache Drill JIRA: DRILL-3962");
- throw new UnsupportedOperationException();
- }
- }
private boolean checkDirExplorers(SqlNode sqlNode) {
final ExprFinder dirExplorersFinder = new
ExprFinder(DirExplorersCondition);
@@ -401,41 +372,6 @@ public class UnsupportedOperatorsVisitor extends
SqlShuttle {
boolean test(SqlNode sqlNode);
}
- /**
- * A condition that returns true if SqlNode has rollup, cube, grouping_sets.
- * */
- private final SqlNodeCondition RollupCubeGrpSets = new SqlNodeCondition() {
- @Override
- public boolean test(SqlNode sqlNode) {
- if (sqlNode instanceof SqlCall) {
- final SqlOperator operator =
DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper(((SqlCall)
sqlNode).getOperator());
- if (operator == SqlStdOperatorTable.ROLLUP
- || operator == SqlStdOperatorTable.CUBE
- || operator == SqlStdOperatorTable.GROUPING_SETS) {
- return true;
- }
- }
- return false;
- }
- };
-
- /**
- * A condition that returns true if SqlNode has Grouping, Grouping_ID,
GROUP_ID.
- */
- private final SqlNodeCondition GroupingID = new SqlNodeCondition() {
- @Override
- public boolean test(SqlNode sqlNode) {
- if (sqlNode instanceof SqlCall) {
- final SqlOperator operator =
DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper(((SqlCall)
sqlNode).getOperator());
- if (operator == SqlStdOperatorTable.GROUPING
- || operator == SqlStdOperatorTable.GROUPING_ID
- || operator == SqlStdOperatorTable.GROUP_ID) {
- return true;
- }
- }
- return false;
- }
- };
/**
* A condition that returns true if SqlNode has Directory Explorers.
diff --git
a/exec/java-exec/src/test/java/org/apache/drill/TestDisabledFunctionality.java
b/exec/java-exec/src/test/java/org/apache/drill/TestDisabledFunctionality.java
index b1649fc535..ce7a79c131 100644
---
a/exec/java-exec/src/test/java/org/apache/drill/TestDisabledFunctionality.java
+++
b/exec/java-exec/src/test/java/org/apache/drill/TestDisabledFunctionality.java
@@ -245,75 +245,4 @@ public class TestDisabledFunctionality extends
BaseTestQuery {
resetSessionOption(PlannerSettings.ENABLE_DECIMAL_DATA_TYPE_KEY);
}
}
-
- @Test (expected = UnsupportedFunctionException.class) //DRILL-3802
- public void testDisableRollup() throws Exception{
- try {
- test("select n_regionkey, count(*) as cnt from cp.`tpch/nation.parquet`
group by rollup(n_regionkey, n_name)");
- } catch(UserException ex) {
- throwAsUnsupportedException(ex);
- throw ex;
- }
- }
-
- @Test (expected = UnsupportedFunctionException.class) //DRILL-3802
- public void testDisableCube() throws Exception{
- try {
- test("select n_regionkey, count(*) as cnt from cp.`tpch/nation.parquet`
group by cube(n_regionkey, n_name)");
- } catch(UserException ex) {
- throwAsUnsupportedException(ex);
- throw ex;
- }
- }
-
- @Test (expected = UnsupportedFunctionException.class) //DRILL-3802
- public void testDisableGroupingSets() throws Exception{
- try {
- test("select n_regionkey, count(*) as cnt from cp.`tpch/nation.parquet`
group by grouping sets(n_regionkey, n_name)");
- } catch(UserException ex) {
- throwAsUnsupportedException(ex);
- throw ex;
- }
- }
-
- @Test (expected = UnsupportedFunctionException.class) //DRILL-3802
- public void testDisableGrouping() throws Exception{
- try {
- test("select n_regionkey, count(*), GROUPING(n_regionkey) from
cp.`tpch/nation.parquet` group by n_regionkey;");
- } catch(UserException ex) {
- throwAsUnsupportedException(ex);
- throw ex;
- }
- }
-
- @Test (expected = UnsupportedFunctionException.class) //DRILL-3802
- public void testDisableGrouping_ID() throws Exception{
- try {
- test("select n_regionkey, count(*), GROUPING_ID(n_regionkey) from
cp.`tpch/nation.parquet` group by n_regionkey;");
- } catch(UserException ex) {
- throwAsUnsupportedException(ex);
- throw ex;
- }
- }
-
- @Test (expected = UnsupportedFunctionException.class) //DRILL-3802
- public void testDisableGroup_ID() throws Exception{
- try {
- test("select n_regionkey, count(*), GROUP_ID() from
cp.`tpch/nation.parquet` group by n_regionkey;");
- } catch(UserException ex) {
- throwAsUnsupportedException(ex);
- throw ex;
- }
- }
-
- @Test (expected = UnsupportedFunctionException.class) //DRILL-3802
- public void testDisableGroupingInFilter() throws Exception{
- try {
- test("select n_regionkey, count(*) from cp.`tpch/nation.parquet` group
by n_regionkey HAVING GROUPING(n_regionkey) = 1");
- } catch(UserException ex) {
- throwAsUnsupportedException(ex);
- throw ex;
- }
- }
-
}
diff --git
a/exec/java-exec/src/test/java/org/apache/drill/TestGroupingSetsResults.java
b/exec/java-exec/src/test/java/org/apache/drill/TestGroupingSetsResults.java
new file mode 100644
index 0000000000..2c05e21259
--- /dev/null
+++ b/exec/java-exec/src/test/java/org/apache/drill/TestGroupingSetsResults.java
@@ -0,0 +1,372 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.drill;
+
+import org.apache.drill.categories.OperatorTest;
+import org.apache.drill.categories.SqlTest;
+import org.apache.drill.test.ClusterFixture;
+import org.apache.drill.test.ClusterFixtureBuilder;
+import org.apache.drill.test.ClusterTest;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+
+@Category({SqlTest.class, OperatorTest.class})
+public class TestGroupingSetsResults extends ClusterTest {
+
+ @BeforeClass
+ public static void setUp() throws Exception {
+ ClusterFixtureBuilder builder = ClusterFixture.builder(dirTestWatcher);
+ startCluster(builder);
+ }
+
+ @Test
+ public void testSimpleGroupingSetsResults() throws Exception {
+ String query = "select n_regionkey, count(*) as cnt " +
+ "from cp.`tpch/nation.parquet` " +
+ "group by grouping sets ((n_regionkey), ())";
+
+ testBuilder()
+ .sqlQuery(query)
+ .unOrdered()
+ .baselineColumns("n_regionkey", "cnt")
+ .baselineValues(0, 5L)
+ .baselineValues(1, 5L)
+ .baselineValues(2, 5L)
+ .baselineValues(3, 5L)
+ .baselineValues(4, 5L)
+ .baselineValues(null, 25L) // Grand total
+ .go();
+ }
+
+ @Test
+ public void testRollupResults() throws Exception {
+ // ROLLUP(a) creates grouping sets: (a), ()
+ String query = "select n_regionkey, count(*) as cnt " +
+ "from cp.`tpch/nation.parquet` " +
+ "where n_regionkey < 2 " +
+ "group by rollup(n_regionkey)";
+
+ testBuilder()
+ .sqlQuery(query)
+ .unOrdered()
+ .baselineColumns("n_regionkey", "cnt")
+ .baselineValues(0, 5L) // Region 0
+ .baselineValues(1, 5L) // Region 1
+ .baselineValues(null, 10L) // Grand total
+ .go();
+ }
+
+ @Test
+ public void testCubeResults() throws Exception {
+ // CUBE(a) creates grouping sets: (a), ()
+ String query = "select n_regionkey, count(*) as cnt " +
+ "from cp.`tpch/nation.parquet` " +
+ "where n_regionkey < 2 " +
+ "group by cube(n_regionkey)";
+
+ testBuilder()
+ .sqlQuery(query)
+ .unOrdered()
+ .baselineColumns("n_regionkey", "cnt")
+ .baselineValues(0, 5L) // Region 0
+ .baselineValues(1, 5L) // Region 1
+ .baselineValues(null, 10L) // Grand total
+ .go();
+ }
+
+ @Test
+ public void testMultiColumnGroupingSets() throws Exception {
+ // Test GROUPING SETS with two columns
+ String query = "select n_regionkey, n_nationkey, count(*) as cnt " +
+ "from cp.`tpch/nation.parquet` " +
+ "where n_regionkey = 0 and n_nationkey in (0, 5) " +
+ "group by grouping sets ((n_regionkey, n_nationkey), (n_regionkey),
())";
+
+ testBuilder()
+ .sqlQuery(query)
+ .unOrdered()
+ .baselineColumns("n_regionkey", "n_nationkey", "cnt")
+ // Grouping set (n_regionkey, n_nationkey)
+ .baselineValues(0, 0, 1L) // Region 0, nation 0
+ .baselineValues(0, 5, 1L) // Region 0, nation 5
+ // Grouping set (n_regionkey)
+ .baselineValues(0, null, 2L) // Region 0 total
+ // Grouping set ()
+ .baselineValues(null, null, 2L) // Grand total
+ .go();
+ }
+
+ @Test
+ public void testRollupTwoColumns() throws Exception {
+ // ROLLUP(a, b) creates grouping sets: (a, b), (a), ()
+ String query = "select n_regionkey, n_nationkey, count(*) as cnt " +
+ "from cp.`tpch/nation.parquet` " +
+ "where n_regionkey = 0 and n_nationkey in (0, 5) " +
+ "group by rollup(n_regionkey, n_nationkey)";
+
+ testBuilder()
+ .sqlQuery(query)
+ .unOrdered()
+ .baselineColumns("n_regionkey", "n_nationkey", "cnt")
+ // Grouping set (n_regionkey, n_nationkey)
+ .baselineValues(0, 0, 1L) // Region 0, nation 0
+ .baselineValues(0, 5, 1L) // Region 0, nation 5
+ // Grouping set (n_regionkey)
+ .baselineValues(0, null, 2L) // Region 0 subtotal
+ // Grouping set ()
+ .baselineValues(null, null, 2L) // Grand total
+ .go();
+ }
+
+ @Test
+ public void testCubeTwoColumns() throws Exception {
+ // CUBE(a, b) creates grouping sets: (a, b), (a), (b), ()
+ // Using specific nations to make the test deterministic
+ String query = "select n_regionkey, n_nationkey, count(*) as cnt " +
+ "from cp.`tpch/nation.parquet` " +
+ "where (n_regionkey = 0 and n_nationkey in (0, 5)) " +
+ " or (n_regionkey = 1 and n_nationkey in (1, 2)) " +
+ "group by cube(n_regionkey, n_nationkey)";
+
+ testBuilder()
+ .sqlQuery(query)
+ .unOrdered()
+ .baselineColumns("n_regionkey", "n_nationkey", "cnt")
+ // Grouping set (n_regionkey, n_nationkey)
+ .baselineValues(0, 0, 1L) // Region 0, nation 0
+ .baselineValues(0, 5, 1L) // Region 0, nation 5
+ .baselineValues(1, 1, 1L) // Region 1, nation 1
+ .baselineValues(1, 2, 1L) // Region 1, nation 2
+ // Grouping set (n_regionkey)
+ .baselineValues(0, null, 2L) // Region 0 total
+ .baselineValues(1, null, 2L) // Region 1 total
+ // Grouping set (n_nationkey)
+ .baselineValues(null, 0, 1L) // Nation 0 across all regions
+ .baselineValues(null, 1, 1L) // Nation 1 across all regions
+ .baselineValues(null, 2, 1L) // Nation 2 across all regions
+ .baselineValues(null, 5, 1L) // Nation 5 across all regions
+ // Grouping set ()
+ .baselineValues(null, null, 4L) // Grand total
+ .go();
+ }
+
+ @Test
+ public void testGroupingSetsWithAggregates() throws Exception {
+ // Test multiple aggregate functions with GROUPING SETS
+ String query = "select n_regionkey, " +
+ "count(*) as cnt, " +
+ "min(n_nationkey) as min_key, " +
+ "max(n_nationkey) as max_key " +
+ "from cp.`tpch/nation.parquet` " +
+ "where n_regionkey < 2 " +
+ "group by grouping sets ((n_regionkey), ())";
+
+ testBuilder()
+ .sqlQuery(query)
+ .unOrdered()
+ .baselineColumns("n_regionkey", "cnt", "min_key", "max_key")
+ .baselineValues(0, 5L, 0, 16) // Region 0
+ .baselineValues(1, 5L, 1, 24) // Region 1
+ .baselineValues(null, 10L, 0, 24) // Grand total
+ .go();
+ }
+
+ @Test
+ public void testGroupingSetsEmptyGroupingSet() throws Exception {
+ // Test just the empty grouping set (grand total only)
+ String query = "select count(*) as cnt, sum(n_nationkey) as sum_key " +
+ "from cp.`tpch/nation.parquet` " +
+ "group by grouping sets (())";
+
+ testBuilder()
+ .sqlQuery(query)
+ .unOrdered()
+ .baselineColumns("cnt", "sum_key")
+ .baselineValues(25L, 300L) // Grand total: 25 nations, sum
0+1+2+...+24 = 300
+ .go();
+ }
+
+ @Test
+ public void testGroupingSetsWithWhere() throws Exception {
+ // Test GROUPING SETS with WHERE clause
+ String query = "select n_regionkey, count(*) as cnt " +
+ "from cp.`tpch/nation.parquet` " +
+ "where n_regionkey in (0, 1, 2) " +
+ "group by grouping sets ((n_regionkey), ())";
+
+ testBuilder()
+ .sqlQuery(query)
+ .unOrdered()
+ .baselineColumns("n_regionkey", "cnt")
+ .baselineValues(0, 5L)
+ .baselineValues(1, 5L)
+ .baselineValues(2, 5L)
+ .baselineValues(null, 15L) // Total of regions 0, 1, 2
+ .go();
+ }
+
+ @Test
+ public void testGroupingSetsWithExpression() throws Exception {
+ // Test GROUPING SETS with computed columns
+ String query = "select n_regionkey, " +
+ "case when n_nationkey < 10 then 'low' else 'high' end as key_range, "
+
+ "count(*) as cnt " +
+ "from cp.`tpch/nation.parquet` " +
+ "where n_regionkey < 2 " +
+ "group by grouping sets (" +
+ " (n_regionkey, case when n_nationkey < 10 then 'low' else 'high'
end), " +
+ " (n_regionkey)" +
+ ")";
+
+ testBuilder()
+ .sqlQuery(query)
+ .unOrdered()
+ .baselineColumns("n_regionkey", "key_range", "cnt")
+ // Grouping set (n_regionkey, key_range)
+ .baselineValues(0, "low", 2L) // Region 0, low keys (0,5)
+ .baselineValues(0, "high", 3L) // Region 0, high keys (14,15,16)
+ .baselineValues(1, "low", 3L) // Region 1, low keys (1,2,3)
+ .baselineValues(1, "high", 2L) // Region 1, high keys (17,24)
+ // Grouping set (n_regionkey)
+ .baselineValues(0, null, 5L) // Region 0 total
+ .baselineValues(1, null, 5L) // Region 1 total
+ .go();
+ }
+
+ @Test
+ public void testRollupWithJSON() throws Exception {
+ // Test ROLLUP with JSON data
+ String query = "select education_level, count(*) as cnt " +
+ "from cp.`employee.json` " +
+ "where education_level in ('Graduate Degree', 'Bachelors Degree',
'Partial College') " +
+ "group by rollup(education_level)";
+
+ // Smoke test: with proper type handling this query should now run without error; the returned rows are not verified
+ queryBuilder()
+ .sql(query)
+ .run();
+ }
+
+ // Tests for GROUPING() and GROUPING_ID() functions
+ // These functions help distinguish between NULL values that are actual data
+ // versus NULL values inserted by GROUPING SETS/ROLLUP/CUBE operations.
+
+ @Test
+ public void testGroupingFunction() throws Exception {
+ // Test GROUPING function with ROLLUP
+ // GROUPING returns 1 if the column is aggregated (NULL in output), 0
otherwise
+ String query = "select education_level, " +
+ "GROUPING(education_level) as grp, " +
+ "count(*) as cnt " +
+ "from cp.`employee.json` " +
+ "where education_level in ('Graduate Degree', 'Bachelors Degree') " +
+ "group by rollup(education_level)";
+
+ testBuilder()
+ .sqlQuery(query)
+ .unOrdered()
+ .baselineColumns("education_level", "grp", "cnt")
+ .baselineValues("Graduate Degree", 0, 170L) // Not aggregated: grp=0
+ .baselineValues("Bachelors Degree", 0, 287L) // Not aggregated: grp=0
+ .baselineValues(null, 1, 457L) // Aggregated (grand
total): grp=1
+ .go();
+ }
+
+ @Test
+ public void testGroupingIdFunction() throws Exception {
+ // Test GROUPING_ID function with CUBE
+ // GROUPING_ID returns a bitmap with one bit per argument (last argument = bit 0); a bit is 1 if that column is aggregated
+ // For CUBE(marital_status, education_level), we get grouping sets:
+ // (marital_status, education_level), (marital_status), (education_level),
()
+ String query = "select marital_status, education_level, " +
+ "GROUPING_ID(marital_status, education_level) as grp_id, " +
+ "count(*) as cnt " +
+ "from cp.`employee.json` " +
+ "where marital_status in ('S', 'M') " +
+ "and education_level in ('Graduate Degree', 'Bachelors Degree') " +
+ "group by cube(marital_status, education_level)";
+
+ testBuilder()
+ .sqlQuery(query)
+ .unOrdered()
+ .baselineColumns("marital_status", "education_level", "grp_id", "cnt")
+ // (marital_status, education_level) - neither aggregated: grp_id = 0
+ .baselineValues("S", "Graduate Degree", 0, 85L)
+ .baselineValues("S", "Bachelors Degree", 0, 143L)
+ .baselineValues("M", "Graduate Degree", 0, 85L)
+ .baselineValues("M", "Bachelors Degree", 0, 144L)
+ // (marital_status) - education_level aggregated: grp_id = 1 (bit 0
set)
+ .baselineValues("S", null, 1, 228L)
+ .baselineValues("M", null, 1, 229L)
+ // (education_level) - marital_status aggregated: grp_id = 2 (bit 1
set)
+ .baselineValues(null, "Graduate Degree", 2, 170L)
+ .baselineValues(null, "Bachelors Degree", 2, 287L)
+ // () - both aggregated: grp_id = 3 (both bits set)
+ .baselineValues(null, null, 3, 457L)
+ .go();
+ }
+
+ @Test
+ public void testGroupIdFunction() throws Exception {
+ // Test GROUP_ID function with duplicate grouping sets
+ // GROUP_ID() returns 0 for first occurrence, 1 for second, etc.
+ String query = "select n_regionkey, " +
+ "GROUP_ID() as grp_id, " +
+ "count(*) as cnt " +
+ "from cp.`tpch/nation.parquet` " +
+ "where n_regionkey < 2 " +
+ "group by grouping sets ((n_regionkey), (n_regionkey), ()) " +
+ "order by grp_id, n_regionkey nulls last";
+
+ testBuilder()
+ .sqlQuery(query)
+ .ordered()
+ .baselineColumns("n_regionkey", "grp_id", "cnt")
+ // First occurrence of (n_regionkey): grp_id = 0
+ .baselineValues(0, 0L, 5L) // Region 0
+ .baselineValues(1, 0L, 5L) // Region 1
+ .baselineValues(null, 0L, 10L) // Empty grouping set
+ // Second occurrence of (n_regionkey): grp_id = 1
+ .baselineValues(0, 1L, 5L) // Region 0
+ .baselineValues(1, 1L, 5L) // Region 1
+ .go();
+ }
+
+ @Test
+ public void testGroupIdNoDuplicates() throws Exception {
+ // Test GROUP_ID when there are no duplicate grouping sets
+ // All GROUP_ID values should be 0
+ String query = "select n_regionkey, " +
+ "GROUP_ID() as grp_id, " +
+ "count(*) as cnt " +
+ "from cp.`tpch/nation.parquet` " +
+ "where n_regionkey < 2 " +
+ "group by grouping sets ((n_regionkey), ()) " +
+ "order by n_regionkey nulls last";
+
+ testBuilder()
+ .sqlQuery(query)
+ .ordered()
+ .baselineColumns("n_regionkey", "grp_id", "cnt")
+ .baselineValues(0, 0L, 5L)
+ .baselineValues(1, 0L, 5L)
+ .baselineValues(null, 0L, 10L)
+ .go();
+ }
+}