This is an automated email from the ASF dual-hosted git repository.
jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new e2c5e73970 Pass literal within AggregateCall via rexList (#13282)
e2c5e73970 is described below
commit e2c5e73970b1e8f64df7c763c5bcac36ff19d2a6
Author: Xiaotian (Jackie) Jiang <[email protected]>
AuthorDate: Fri May 31 18:00:00 2024 -0700
Pass literal within AggregateCall via rexList (#13282)
---
.../pinot/calcite/rel/hint/PinotHintOptions.java | 13 -
.../PinotAggregateExchangeNodeInsertRule.java | 422 ++++++++++-----------
.../rules/PinotAggregateLiteralAttachmentRule.java | 107 ------
.../calcite/rel/rules/PinotQueryRuleSets.java | 5 -
.../org/apache/pinot/query/QueryEnvironment.java | 4 -
.../query/parser/CalciteRexExpressionParser.java | 4 +-
.../query/planner/logical/LiteralHintUtils.java | 85 -----
.../query/planner/logical/RexExpressionUtils.java | 6 +-
.../apache/pinot/query/QueryCompilationTest.java | 3 +-
.../src/test/resources/queries/GroupByPlans.json | 18 +-
.../src/test/resources/queries/OrderByPlans.json | 4 +-
.../test/resources/queries/PinotHintablePlans.json | 33 +-
.../query/runtime/operator/AggregateOperator.java | 125 ++----
.../src/test/resources/queries/QueryHints.json | 8 +-
.../pinot/segment/spi/AggregationFunctionType.java | 7 +-
15 files changed, 256 insertions(+), 588 deletions(-)
diff --git
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintOptions.java
b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintOptions.java
index 1d53a3184e..99e07b61df 100644
---
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintOptions.java
+++
b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintOptions.java
@@ -20,7 +20,6 @@ package org.apache.pinot.calcite.rel.hint;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.hint.RelHint;
-import org.apache.pinot.query.planner.logical.LiteralHintUtils;
/**
@@ -47,18 +46,6 @@ public class PinotHintOptions {
public static class InternalAggregateOptions {
public static final String AGG_TYPE = "agg_type";
- /**
- * agg call signature is used to store LITERAL inputs to the Aggregate
Call. which is not supported in Calcite
- * here
- * 1. we store the Map of Pair[aggCallIdx, argListIdx] to RexLiteral to
indicate the RexLiteral being passed into
- * the aggregateCalls[aggCallIdx].operandList[argListIdx] is supposed
to be a RexLiteral.
- * 2. not all RexLiteral types are supported to be part of the input
constant call signature.
- * 3. RexLiteral are encoded as String and decoded as Pinot Literal
objects.
- *
- * see: {@link LiteralHintUtils}.
- * see: https://issues.apache.org/jira/projects/CALCITE/issues/CALCITE-5833
- */
- public static final String AGG_CALL_SIGNATURE = "agg_call_signature";
}
public static class AggregateOptions {
diff --git
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java
b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java
index ffe0741751..0e6e13b0e7 100644
---
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java
+++
b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java
@@ -19,20 +19,16 @@
package org.apache.pinot.calcite.rel.rules;
import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
-import java.util.Collections;
import java.util.HashMap;
import java.util.List;
-import java.util.Locale;
import java.util.Map;
-import java.util.Set;
import javax.annotation.Nullable;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
-import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelCollations;
+import org.apache.calcite.rel.RelDistribution;
import org.apache.calcite.rel.RelDistributions;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.RelNode;
@@ -44,16 +40,16 @@ import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.rules.AggregateExtractProjectRule;
import org.apache.calcite.rel.rules.AggregateReduceFunctionsRule;
import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
-import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableIntList;
-import org.apache.calcite.util.Util;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.MappingType;
import org.apache.calcite.util.mapping.Mappings;
@@ -88,8 +84,6 @@ import org.apache.pinot.segment.spi.AggregationFunctionType;
public class PinotAggregateExchangeNodeInsertRule extends RelOptRule {
public static final PinotAggregateExchangeNodeInsertRule INSTANCE =
new
PinotAggregateExchangeNodeInsertRule(PinotRuleUtils.PINOT_REL_FACTORY);
- public static final Set<String> LIST_AGG_FUNCTION_NAMES =
- ImmutableSet.of("LISTAGG", "LIST_AGG", "ARRAYsAGG", "ARRAY_AGG");
public PinotAggregateExchangeNodeInsertRule(RelBuilderFactory factory) {
super(operand(LogicalAggregate.class, any()), factory, null);
@@ -119,137 +113,104 @@ public class PinotAggregateExchangeNodeInsertRule
extends RelOptRule {
*/
@Override
public void onMatch(RelOptRuleCall call) {
- Aggregate oldAggRel = call.rel(0);
- ImmutableList<RelHint> oldHints = oldAggRel.getHints();
- // Both collation and distinct are not supported in leaf stage aggregation.
- boolean hasCollation = hasCollation(oldAggRel);
- boolean hasDistinct = hasDistinct(oldAggRel);
- Aggregate newAgg;
- if (!oldAggRel.getGroupSet().isEmpty() &&
PinotHintStrategyTable.isHintOptionTrue(oldHints,
- PinotHintOptions.AGGREGATE_HINT_OPTIONS,
PinotHintOptions.AggregateOptions.IS_PARTITIONED_BY_GROUP_BY_KEYS)) {
- //
------------------------------------------------------------------------
- // If the "is_partitioned_by_group_by_keys" aggregate hint option is
set, just add additional hints indicating
- // this is a single stage aggregation.
- List<RelHint> newHints =
PinotHintStrategyTable.replaceHintOptions(oldAggRel.getHints(),
- PinotHintOptions.INTERNAL_AGG_OPTIONS,
PinotHintOptions.InternalAggregateOptions.AGG_TYPE,
- AggType.DIRECT.name());
- newAgg =
- new LogicalAggregate(oldAggRel.getCluster(),
oldAggRel.getTraitSet(), newHints, oldAggRel.getInput(),
- oldAggRel.getGroupSet(), oldAggRel.getGroupSets(),
oldAggRel.getAggCallList());
- } else if (hasCollation || hasDistinct ||
(!oldAggRel.getGroupSet().isEmpty()
- && PinotHintStrategyTable.isHintOptionTrue(oldHints,
PinotHintOptions.AGGREGATE_HINT_OPTIONS,
+ Aggregate argRel = call.rel(0);
+ ImmutableList<RelHint> hints = argRel.getHints();
+ // Collation is not supported in leaf stage aggregation.
+ RelCollation collation = extractWithInGroupCollation(argRel);
+ boolean hasGroupBy = !argRel.getGroupSet().isEmpty();
+ if (collation != null || (hasGroupBy &&
PinotHintStrategyTable.isHintOptionTrue(hints,
+ PinotHintOptions.AGGREGATE_HINT_OPTIONS,
PinotHintOptions.AggregateOptions.SKIP_LEAF_STAGE_GROUP_BY_AGGREGATION))) {
- //
------------------------------------------------------------------------
- // If "is_skip_leaf_stage_group_by" SQLHint option is passed, the leaf
stage aggregation is skipped.
- newAgg = (Aggregate) createPlanWithExchangeDirectAggregation(call);
+ call.transformTo(createPlanWithExchangeDirectAggregation(call,
collation));
+ } else if (hasGroupBy && PinotHintStrategyTable.isHintOptionTrue(hints,
PinotHintOptions.AGGREGATE_HINT_OPTIONS,
+ PinotHintOptions.AggregateOptions.IS_PARTITIONED_BY_GROUP_BY_KEYS)) {
+ call.transformTo(createPlanWithDirectAggregation(call));
} else {
- //
------------------------------------------------------------------------
- newAgg = (Aggregate) createPlanWithLeafExchangeFinalAggregate(call);
+ call.transformTo(createPlanWithLeafExchangeFinalAggregate(call));
}
- call.transformTo(newAgg);
}
- private boolean hasDistinct(Aggregate aggRel) {
+ // TODO: Currently it only handles one WITHIN GROUP collation across all
AggregateCalls.
+ @Nullable
+ private static RelCollation extractWithInGroupCollation(Aggregate aggRel) {
for (AggregateCall aggCall : aggRel.getAggCallList()) {
- // If the aggregation function is a list aggregation function and it is
distinct, we can skip leaf stage.
- // For COUNT(DISTINCT), there could be more leaf stage optimization.
- if (aggCall.isDistinct() &&
LIST_AGG_FUNCTION_NAMES.contains(aggCall.getAggregation().getName().toUpperCase()))
{
- return true;
+ List<RelFieldCollation> fieldCollations =
aggCall.getCollation().getFieldCollations();
+ if (!fieldCollations.isEmpty()) {
+ return RelCollations.of(fieldCollations);
}
}
- return false;
+ return null;
}
- private boolean hasCollation(Aggregate aggRel) {
- for (AggregateCall aggCall : aggRel.getAggCallList()) {
- if (!aggCall.getCollation().getKeys().isEmpty()) {
- return true;
- }
- }
- return false;
+ private static RelNode createPlanWithDirectAggregation(RelOptRuleCall call) {
+ Aggregate aggRel = call.rel(0);
+ List<RelHint> newHints =
+ PinotHintStrategyTable.replaceHintOptions(aggRel.getHints(),
PinotHintOptions.INTERNAL_AGG_OPTIONS,
+ PinotHintOptions.InternalAggregateOptions.AGG_TYPE,
AggType.DIRECT.name());
+ return new LogicalAggregate(aggRel.getCluster(), aggRel.getTraitSet(),
newHints, aggRel.getInput(),
+ aggRel.getGroupSet(), aggRel.getGroupSets(), buildAggCalls(aggRel,
AggType.DIRECT));
}
/**
* Aggregate node will be split into LEAF + exchange + FINAL.
- * optionally we can insert INTERMEDIATE to reduce hotspot in the future.
+ * TODO: Add optional INTERMEDIATE stage to reduce hotspot.
*/
- private RelNode createPlanWithLeafExchangeFinalAggregate(RelOptRuleCall
call) {
- // TODO: add optional intermediate stage here when hinted.
- Aggregate oldAggRel = call.rel(0);
- // 1. attach leaf agg RelHint to original agg. Perform any aggregation
call conversions necessary
- Aggregate leafAgg = convertAggForLeafInput(oldAggRel);
- // 2. attach exchange.
- List<Integer> groupSetIndices = ImmutableIntList.range(0,
oldAggRel.getGroupCount());
- PinotLogicalExchange exchange;
- if (groupSetIndices.size() == 0) {
- exchange = PinotLogicalExchange.create(leafAgg,
RelDistributions.hash(Collections.emptyList()));
- } else {
- exchange = PinotLogicalExchange.create(leafAgg,
RelDistributions.hash(groupSetIndices));
- }
- // 3. attach final agg stage.
- return convertAggFromIntermediateInput(call, oldAggRel, exchange,
AggType.FINAL);
+ private static RelNode
createPlanWithLeafExchangeFinalAggregate(RelOptRuleCall call) {
+ Aggregate aggRel = call.rel(0);
+ // Create a LEAF aggregate.
+ Aggregate leafAggRel = convertAggForLeafInput(aggRel);
+ // Create an exchange node over the LEAF aggregate.
+ PinotLogicalExchange exchange = PinotLogicalExchange.create(leafAggRel,
+ RelDistributions.hash(ImmutableIntList.range(0,
aggRel.getGroupCount())));
+ // Create a FINAL aggregate over the exchange.
+ return convertAggFromIntermediateInput(call, exchange, AggType.FINAL);
}
/**
* Use this group by optimization to skip leaf stage aggregation when
aggregating at leaf level is not desired.
* Many situation could be wasted effort to do group-by on leaf, eg: when
cardinality of group by column is very high.
*/
- private RelNode createPlanWithExchangeDirectAggregation(RelOptRuleCall call)
{
- Aggregate oldAggRel = call.rel(0);
- List<RelHint> newHints =
PinotHintStrategyTable.replaceHintOptions(oldAggRel.getHints(),
- PinotHintOptions.INTERNAL_AGG_OPTIONS,
PinotHintOptions.InternalAggregateOptions.AGG_TYPE,
- AggType.DIRECT.name());
-
- // Convert Aggregate WithGroup Collation into a Sort
- RelCollation relCollation = extractWithInGroupCollation(oldAggRel);
+ private static RelNode
createPlanWithExchangeDirectAggregation(RelOptRuleCall call,
+ @Nullable RelCollation collation) {
+ Aggregate aggRel = call.rel(0);
+ RelNode input = aggRel.getInput();
+ // Create Project when there's none below the aggregate.
+ if (!(PinotRuleUtils.unboxRel(input) instanceof Project)) {
+ aggRel = (Aggregate) generateProjectUnderAggregate(call);
+ input = aggRel.getInput();
+ }
- // create project when there's none below the aggregate to reduce exchange
overhead
- RelNode childRel = ((HepRelVertex) oldAggRel.getInput()).getCurrentRel();
- if (!(childRel instanceof Project)) {
- return convertAggForExchangeDirectAggregate(call, newHints,
relCollation);
+ ImmutableBitSet groupSet = aggRel.getGroupSet();
+ RelDistribution distribution = RelDistributions.hash(groupSet.asList());
+ RelNode exchange;
+ if (collation != null) {
+ // Insert a LogicalSort node between exchange and aggregate when
collation exists.
+ exchange = PinotLogicalSortExchange.create(input, distribution,
collation, false, true);
} else {
- // create normal exchange
- List<Integer> groupSetIndices = new ArrayList<>();
- oldAggRel.getGroupSet().forEach(groupSetIndices::add);
- RelNode newAggChild;
- if (relCollation != null) {
- newAggChild =
- (groupSetIndices.isEmpty()) ?
PinotLogicalSortExchange.create(childRel, RelDistributions.SINGLETON,
- relCollation, false, true)
- : PinotLogicalSortExchange.create(childRel,
RelDistributions.hash(groupSetIndices),
- relCollation, false, true);
- } else {
- newAggChild = PinotLogicalExchange.create(childRel,
RelDistributions.hash(groupSetIndices));
- }
- return new LogicalAggregate(oldAggRel.getCluster(),
oldAggRel.getTraitSet(), newHints, newAggChild,
- oldAggRel.getGroupSet(), oldAggRel.getGroupSets(),
oldAggRel.getAggCallList());
+ exchange = PinotLogicalExchange.create(input, distribution);
}
- }
- // Extract the first collation in the AggregateCall list
- @Nullable
- private RelCollation extractWithInGroupCollation(Aggregate aggRel) {
- for (AggregateCall aggCall : aggRel.getAggCallList()) {
- List<RelFieldCollation> fieldCollations =
aggCall.getCollation().getFieldCollations();
- if (!fieldCollations.isEmpty()) {
- return RelCollations.of(fieldCollations);
- }
- }
- return null;
+ List<RelHint> newHints =
+ PinotHintStrategyTable.replaceHintOptions(aggRel.getHints(),
PinotHintOptions.INTERNAL_AGG_OPTIONS,
+ PinotHintOptions.InternalAggregateOptions.AGG_TYPE,
AggType.DIRECT.name());
+ return new LogicalAggregate(aggRel.getCluster(), aggRel.getTraitSet(),
newHints, exchange, groupSet,
+ aggRel.getGroupSets(), buildAggCalls(aggRel, AggType.DIRECT));
}
/**
- * The following is copied from {@link
AggregateExtractProjectRule#onMatch(RelOptRuleCall)}
- * with modification to insert an exchange in between the Aggregate and
Project
+ * The following is copied from {@link
AggregateExtractProjectRule#onMatch(RelOptRuleCall)} with modification to take
+ * aggregate input as input.
*/
- private RelNode convertAggForExchangeDirectAggregate(RelOptRuleCall call,
List<RelHint> newHints,
- @Nullable RelCollation collation) {
+ private static RelNode generateProjectUnderAggregate(RelOptRuleCall call) {
final Aggregate aggregate = call.rel(0);
+ // --------------- MODIFIED ---------------
final RelNode input = aggregate.getInput();
+ // final RelNode input = call.rel(1);
+ // ------------- END MODIFIED -------------
+
// Compute which input fields are used.
// 1. group fields are always used
- final ImmutableBitSet.Builder inputFieldsUsed =
- aggregate.getGroupSet().rebuild();
+ final ImmutableBitSet.Builder inputFieldsUsed =
aggregate.getGroupSet().rebuild();
// 2. agg functions
for (AggregateCall aggCall : aggregate.getAggCallList()) {
for (int i : aggCall.getArgList()) {
@@ -259,149 +220,164 @@ public class PinotAggregateExchangeNodeInsertRule
extends RelOptRule {
inputFieldsUsed.set(aggCall.filterArg);
}
}
- final RelBuilder relBuilder1 = call.builder().push(input);
+ final RelBuilder relBuilder = call.builder().push(input);
final List<RexNode> projects = new ArrayList<>();
final Mapping mapping =
- Mappings.create(MappingType.INVERSE_SURJECTION,
- aggregate.getInput().getRowType().getFieldCount(),
+ Mappings.create(MappingType.INVERSE_SURJECTION,
aggregate.getInput().getRowType().getFieldCount(),
inputFieldsUsed.cardinality());
int j = 0;
for (int i : inputFieldsUsed.build()) {
- projects.add(relBuilder1.field(i));
+ projects.add(relBuilder.field(i));
mapping.set(i, j++);
}
- relBuilder1.project(projects);
- final ImmutableBitSet newGroupSet =
- Mappings.apply(mapping, aggregate.getGroupSet());
- Project project = (Project) relBuilder1.build();
- // ------------------------------------------------------------------------
- RelNode newAggChild;
- if (collation != null) {
- // Insert a LogicalSort node between the exchange and the aggregate
- newAggChild = newGroupSet.isEmpty() ?
PinotLogicalSortExchange.create(project, RelDistributions.SINGLETON,
- collation, false, true)
- : PinotLogicalSortExchange.create(project,
RelDistributions.hash(newGroupSet.asList()),
- collation, false, true);
- } else {
- newAggChild = PinotLogicalExchange.create(project,
RelDistributions.hash(newGroupSet.asList()));
- }
- // ------------------------------------------------------------------------
+ relBuilder.project(projects);
- final RelBuilder relBuilder2 = call.builder().push(newAggChild);
+ final ImmutableBitSet newGroupSet = Mappings.apply(mapping,
aggregate.getGroupSet());
final List<ImmutableBitSet> newGroupSets =
- aggregate.getGroupSets().stream()
- .map(bitSet -> Mappings.apply(mapping, bitSet))
- .collect(Util.toImmutableList());
+ aggregate.getGroupSets().stream().map(bitSet ->
Mappings.apply(mapping, bitSet))
+ .collect(ImmutableList.toImmutableList());
final List<RelBuilder.AggCall> newAggCallList =
- aggregate.getAggCallList().stream()
- .map(aggCall -> relBuilder2.aggregateCall(aggCall, mapping))
- .collect(Util.toImmutableList());
- final RelBuilder.GroupKey groupKey =
- relBuilder2.groupKey(newGroupSet, newGroupSets);
- relBuilder2.aggregate(groupKey, newAggCallList).hints(newHints);
- return relBuilder2.build();
+ aggregate.getAggCallList().stream().map(aggCall ->
relBuilder.aggregateCall(aggCall, mapping))
+ .collect(ImmutableList.toImmutableList());
+
+ final RelBuilder.GroupKey groupKey = relBuilder.groupKey(newGroupSet,
newGroupSets);
+ relBuilder.aggregate(groupKey, newAggCallList);
+ return relBuilder.build();
}
- private Aggregate convertAggForLeafInput(Aggregate oldAggRel) {
- List<AggregateCall> oldCalls = oldAggRel.getAggCallList();
- List<AggregateCall> newCalls = new ArrayList<>();
- for (AggregateCall oldCall : oldCalls) {
- newCalls.add(buildAggregateCall(oldAggRel.getInput(), oldCall,
oldCall.getArgList(), oldAggRel.getGroupCount(),
- AggType.LEAF));
- }
- List<RelHint> newHints =
PinotHintStrategyTable.replaceHintOptions(oldAggRel.getHints(),
- PinotHintOptions.INTERNAL_AGG_OPTIONS,
PinotHintOptions.InternalAggregateOptions.AGG_TYPE, AggType.LEAF.name());
- return new LogicalAggregate(oldAggRel.getCluster(),
oldAggRel.getTraitSet(), newHints, oldAggRel.getInput(),
- oldAggRel.getGroupSet(), oldAggRel.getGroupSets(), newCalls);
+ private static Aggregate convertAggForLeafInput(Aggregate aggRel) {
+ List<RelHint> newHints =
+ PinotHintStrategyTable.replaceHintOptions(aggRel.getHints(),
PinotHintOptions.INTERNAL_AGG_OPTIONS,
+ PinotHintOptions.InternalAggregateOptions.AGG_TYPE,
AggType.LEAF.name());
+ return new LogicalAggregate(aggRel.getCluster(), aggRel.getTraitSet(),
newHints, aggRel.getInput(),
+ aggRel.getGroupSet(), aggRel.getGroupSets(), buildAggCalls(aggRel,
AggType.LEAF));
}
- private RelNode convertAggFromIntermediateInput(RelOptRuleCall ruleCall,
Aggregate oldAggRel,
- PinotLogicalExchange exchange, AggType aggType) {
- // add the exchange as the input node to the relation builder.
- RelBuilder relBuilder = ruleCall.builder();
- relBuilder.push(exchange);
+ private static RelNode convertAggFromIntermediateInput(RelOptRuleCall call,
PinotLogicalExchange exchange,
+ AggType aggType) {
+ Aggregate aggRel = call.rel(0);
+ RelNode input = PinotRuleUtils.unboxRel(aggRel.getInput());
+ List<RexNode> projects = (input instanceof Project) ? ((Project)
input).getProjects() : null;
- // make input ref to the exchange after the leaf aggregate, all groups
should be at the front
RexBuilder rexBuilder = exchange.getCluster().getRexBuilder();
- final int nGroups = oldAggRel.getGroupCount();
- for (int i = 0; i < nGroups; i++) {
- rexBuilder.makeInputRef(oldAggRel, i);
- }
-
- List<AggregateCall> newCalls = new ArrayList<>();
+ int groupCount = aggRel.getGroupCount();
+ List<AggregateCall> orgAggCalls = aggRel.getAggCallList();
+ int numAggCalls = orgAggCalls.size();
+ List<AggregateCall> aggCalls = new ArrayList<>(numAggCalls);
Map<AggregateCall, RexNode> aggCallMapping = new HashMap<>();
- // create new aggregate function calls from exchange input, all aggCalls
are followed one by one from exchange
- // b/c the exchange produces intermediate results, thus the input to the
newCall will be indexed at
- // [nGroup + oldCallIndex]
- List<AggregateCall> oldCalls = oldAggRel.getAggCallList();
- for (int oldCallIndex = 0; oldCallIndex < oldCalls.size(); oldCallIndex++)
{
- AggregateCall oldCall = oldCalls.get(oldCallIndex);
- // intermediate stage input only supports single argument inputs.
- List<Integer> argList = Collections.singletonList(nGroups +
oldCallIndex);
- AggregateCall newCall = buildAggregateCall(exchange, oldCall, argList,
nGroups, aggType);
- rexBuilder.addAggCall(newCall, nGroups, newCalls, aggCallMapping,
oldAggRel.getInput()::fieldIsNullable);
+ // Create new AggregateCalls from exchange input. Exchange produces
results with group keys followed by intermediate
+ // aggregate results.
+ for (int i = 0; i < numAggCalls; i++) {
+ AggregateCall orgAggCall = orgAggCalls.get(i);
+ List<Integer> argList = orgAggCall.getArgList();
+ int index = groupCount + i;
+ RexInputRef inputRef = RexInputRef.of(index, aggRel.getRowType());
+ // Generate rexList from argList and replace literal reference with
literal. Keep the first argument as is.
+ int numArguments = argList.size();
+ List<RexNode> rexList;
+ if (numArguments <= 1) {
+ rexList = ImmutableList.of(inputRef);
+ } else {
+ rexList = new ArrayList<>(numArguments);
+ rexList.add(inputRef);
+ for (int j = 1; j < numArguments; j++) {
+ int argument = argList.get(j);
+ if (projects != null && projects.get(argument) instanceof
RexLiteral) {
+ rexList.add(projects.get(argument));
+ } else {
+ // Replace all the input references in the rexList with the new input
reference.
+ rexList.add(inputRef);
+ }
+ }
+ }
+ AggregateCall newAggregate = buildAggCall(exchange, orgAggCall, rexList,
groupCount, aggType);
+ rexBuilder.addAggCall(newAggregate, groupCount, aggCalls,
aggCallMapping, aggRel.getInput()::fieldIsNullable);
}
- // create new aggregate relation.
- ImmutableList<RelHint> orgHints = oldAggRel.getHints();
- List<RelHint> newAggHint =
PinotHintStrategyTable.replaceHintOptions(orgHints,
- PinotHintOptions.INTERNAL_AGG_OPTIONS,
PinotHintOptions.InternalAggregateOptions.AGG_TYPE, aggType.name());
- ImmutableBitSet groupSet = ImmutableBitSet.range(nGroups);
- relBuilder.aggregate(relBuilder.groupKey(groupSet,
ImmutableList.of(groupSet)), newCalls);
- relBuilder.hints(newAggHint);
+ RelBuilder relBuilder = call.builder();
+ relBuilder.push(exchange);
+ ImmutableBitSet groupSet = ImmutableBitSet.range(groupCount);
+ relBuilder.aggregate(relBuilder.groupKey(groupSet,
ImmutableList.of(groupSet)), aggCalls);
+ List<RelHint> newHints =
+ PinotHintStrategyTable.replaceHintOptions(aggRel.getHints(),
PinotHintOptions.INTERNAL_AGG_OPTIONS,
+ PinotHintOptions.InternalAggregateOptions.AGG_TYPE,
aggType.name());
+ relBuilder.hints(newHints);
return relBuilder.build();
}
- private static AggregateCall buildAggregateCall(RelNode inputNode,
AggregateCall orgAggCall, List<Integer> argList,
- int numberGroups, AggType aggType) {
- final SqlAggFunction oldAggFunction = orgAggCall.getAggregation();
- final SqlKind aggKind = oldAggFunction.getKind();
- String functionName = getFunctionNameFromAggregateCall(orgAggCall);
- AggregationFunctionType functionType =
AggregationFunctionType.getAggregationFunctionType(functionName);
- // create the aggFunction
- SqlAggFunction sqlAggFunction;
- if (functionType.getIntermediateReturnTypeInference() != null) {
- switch (aggType) {
- case LEAF:
- sqlAggFunction = new
PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null,
- functionType.getSqlKind(),
functionType.getIntermediateReturnTypeInference(), null,
- functionType.getOperandTypeChecker(),
functionType.getSqlFunctionCategory());
- break;
- case INTERMEDIATE:
- sqlAggFunction = new
PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null,
- functionType.getSqlKind(),
functionType.getIntermediateReturnTypeInference(), null,
- OperandTypes.ANY, functionType.getSqlFunctionCategory());
- break;
- case FINAL:
- sqlAggFunction = new
PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null,
- functionType.getSqlKind(),
ReturnTypes.explicit(orgAggCall.getType()), null,
- OperandTypes.ANY, functionType.getSqlFunctionCategory());
- break;
- default:
- throw new UnsupportedOperationException("Unsuppoted aggType: " +
aggType + " for " + functionName);
+ private static List<AggregateCall> buildAggCalls(Aggregate aggRel, AggType
aggType) {
+ RelNode input = PinotRuleUtils.unboxRel(aggRel.getInput());
+ List<RexNode> projects = (input instanceof Project) ? ((Project)
input).getProjects() : null;
+ List<AggregateCall> orgAggCalls = aggRel.getAggCallList();
+ List<AggregateCall> aggCalls = new ArrayList<>(orgAggCalls.size());
+ for (AggregateCall orgAggCall : orgAggCalls) {
+ // Generate rexList from argList and replace literal reference with
literal. Keep the first argument as is.
+ List<Integer> argList = orgAggCall.getArgList();
+ int numArguments = argList.size();
+ List<RexNode> rexList;
+ if (numArguments == 0) {
+ rexList = ImmutableList.of();
+ } else if (numArguments == 1) {
+ rexList = ImmutableList.of(RexInputRef.of(argList.get(0),
input.getRowType()));
+ } else {
+ rexList = new ArrayList<>(numArguments);
+ rexList.add(RexInputRef.of(argList.get(0), input.getRowType()));
+ for (int i = 1; i < numArguments; i++) {
+ int argument = argList.get(i);
+ if (projects != null && projects.get(argument) instanceof
RexLiteral) {
+ rexList.add(projects.get(argument));
+ } else {
+ rexList.add(RexInputRef.of(argument, input.getRowType()));
+ }
+ }
}
- } else {
- sqlAggFunction = oldAggFunction;
+ aggCalls.add(buildAggCall(input, orgAggCall, rexList,
aggRel.getGroupCount(), aggType));
}
-
- return AggregateCall.create(sqlAggFunction,
- functionName.equals("distinctCount") || orgAggCall.isDistinct(),
- orgAggCall.isApproximate(),
- orgAggCall.ignoreNulls(),
- argList,
- aggType.isInputIntermediateFormat() ? -1 : orgAggCall.filterArg,
- orgAggCall.distinctKeys,
- orgAggCall.collation,
- numberGroups,
- inputNode,
- null,
- null);
+ return aggCalls;
}
- private static String getFunctionNameFromAggregateCall(AggregateCall
aggregateCall) {
- return aggregateCall.getAggregation().getName().equalsIgnoreCase("COUNT")
&& aggregateCall.isDistinct()
- ? "distinctCount" : aggregateCall.getAggregation().getName();
+ // TODO: Revisit the following logic:
+ // - DISTINCT is resolved here
+ // - argList is replaced with rexList
+ private static AggregateCall buildAggCall(RelNode input, AggregateCall
orgAggCall, List<RexNode> rexList,
+ int numGroups, AggType aggType) {
+ String functionName = orgAggCall.getAggregation().getName();
+ if (orgAggCall.isDistinct()) {
+ if (functionName.equals("COUNT")) {
+ functionName = "DISTINCTCOUNT";
+ } else if (functionName.equals("LISTAGG")) {
+ rexList.add(input.getCluster().getRexBuilder().makeLiteral(true));
+ }
+ }
+ AggregationFunctionType functionType =
AggregationFunctionType.getAggregationFunctionType(functionName);
+ SqlAggFunction sqlAggFunction;
+ switch (aggType) {
+ case DIRECT:
+ sqlAggFunction = new PinotSqlAggFunction(functionName, null,
functionType.getSqlKind(),
+ ReturnTypes.explicit(orgAggCall.getType()), null,
functionType.getOperandTypeChecker(),
+ functionType.getSqlFunctionCategory());
+ break;
+ case LEAF:
+ sqlAggFunction = new PinotSqlAggFunction(functionName, null,
functionType.getSqlKind(),
+ functionType.getIntermediateReturnTypeInference(), null,
functionType.getOperandTypeChecker(),
+ functionType.getSqlFunctionCategory());
+ break;
+ case INTERMEDIATE:
+ sqlAggFunction = new PinotSqlAggFunction(functionName, null,
functionType.getSqlKind(),
+ functionType.getIntermediateReturnTypeInference(), null,
OperandTypes.ANY,
+ functionType.getSqlFunctionCategory());
+ break;
+ case FINAL:
+ sqlAggFunction = new PinotSqlAggFunction(functionName, null,
functionType.getSqlKind(),
+ ReturnTypes.explicit(orgAggCall.getType()), null,
OperandTypes.ANY, functionType.getSqlFunctionCategory());
+ break;
+ default:
+ throw new IllegalStateException("Unsupported AggType: " + aggType);
+ }
+ return AggregateCall.create(sqlAggFunction, false,
orgAggCall.isApproximate(), orgAggCall.ignoreNulls(), rexList,
+ ImmutableList.of(), aggType.isInputIntermediateFormat() ? -1 :
orgAggCall.filterArg, orgAggCall.distinctKeys,
+ orgAggCall.collation, numGroups, input, null, null);
}
}
diff --git
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateLiteralAttachmentRule.java
b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateLiteralAttachmentRule.java
deleted file mode 100644
index 74af35b47a..0000000000
---
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateLiteralAttachmentRule.java
+++ /dev/null
@@ -1,107 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.pinot.calcite.rel.rules;
-
-import com.google.common.collect.ImmutableList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import org.apache.calcite.plan.RelOptRule;
-import org.apache.calcite.plan.RelOptRuleCall;
-import org.apache.calcite.rel.RelNode;
-import org.apache.calcite.rel.core.Aggregate;
-import org.apache.calcite.rel.core.AggregateCall;
-import org.apache.calcite.rel.core.Project;
-import org.apache.calcite.rel.hint.RelHint;
-import org.apache.calcite.rel.logical.LogicalAggregate;
-import org.apache.calcite.rex.RexLiteral;
-import org.apache.calcite.rex.RexNode;
-import org.apache.calcite.tools.RelBuilderFactory;
-import org.apache.calcite.util.Pair;
-import org.apache.pinot.calcite.rel.hint.PinotHintOptions;
-import org.apache.pinot.calcite.rel.hint.PinotHintStrategyTable;
-import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
-import org.apache.pinot.query.planner.logical.LiteralHintUtils;
-import org.apache.pinot.query.planner.logical.RexExpression;
-import org.apache.pinot.query.planner.logical.RexExpressionUtils;
-
-
-/**
- * Special rule to attach Literal to Aggregate call.
- */
-public class PinotAggregateLiteralAttachmentRule extends RelOptRule {
- public static final PinotAggregateLiteralAttachmentRule INSTANCE =
- new
PinotAggregateLiteralAttachmentRule(PinotRuleUtils.PINOT_REL_FACTORY);
-
- public PinotAggregateLiteralAttachmentRule(RelBuilderFactory factory) {
- super(operand(LogicalAggregate.class, any()), factory, null);
- }
-
- @Override
- public boolean matches(RelOptRuleCall call) {
- if (call.rels.length < 1) {
- return false;
- }
- if (call.rel(0) instanceof Aggregate) {
- Aggregate agg = call.rel(0);
- ImmutableList<RelHint> hints = agg.getHints();
- return !PinotHintStrategyTable.containsHintOption(hints,
- PinotHintOptions.INTERNAL_AGG_OPTIONS,
PinotHintOptions.InternalAggregateOptions.AGG_CALL_SIGNATURE);
- }
- return false;
- }
-
- @Override
- public void onMatch(RelOptRuleCall call) {
- Aggregate aggregate = call.rel(0);
- Map<Pair<Integer, Integer>, RexExpression.Literal> rexLiterals =
extractLiterals(call);
- List<RelHint> newHints =
PinotHintStrategyTable.replaceHintOptions(aggregate.getHints(),
- PinotHintOptions.INTERNAL_AGG_OPTIONS,
PinotHintOptions.InternalAggregateOptions.AGG_CALL_SIGNATURE,
- LiteralHintUtils.literalMapToHintString(rexLiterals));
- // TODO: validate against AggregationFunctionType to see if expected
literal positions are properly attached
- call.transformTo(new LogicalAggregate(aggregate.getCluster(),
aggregate.getTraitSet(), newHints,
- aggregate.getInput(), aggregate.getGroupSet(),
aggregate.getGroupSets(), aggregate.getAggCallList()));
- }
-
- private static Map<Pair<Integer, Integer>, RexExpression.Literal>
extractLiterals(RelOptRuleCall call) {
- Aggregate aggregate = call.rel(0);
- RelNode input = PinotRuleUtils.unboxRel(aggregate.getInput());
- List<RexNode> rexNodes = (input instanceof Project) ? ((Project)
input).getProjects() : null;
- List<AggregateCall> aggCallList = aggregate.getAggCallList();
- final Map<Pair<Integer, Integer>, RexExpression.Literal> rexLiteralMap =
new HashMap<>();
- for (int aggIdx = 0; aggIdx < aggCallList.size(); aggIdx++) {
- AggregateCall aggCall = aggCallList.get(aggIdx);
- int argSize = aggCall.getArgList().size();
- if (argSize > 1) {
- // use -1 argIdx to indicate size of the agg operands.
- rexLiteralMap.put(new Pair<>(aggIdx, -1), new
RexExpression.Literal(ColumnDataType.INT, argSize));
- // put the literals in to the map.
- for (int argIdx = 0; argIdx < argSize; argIdx++) {
- if (rexNodes != null) {
- RexNode field = rexNodes.get(aggCall.getArgList().get(argIdx));
- if (field instanceof RexLiteral) {
- rexLiteralMap.put(new Pair<>(aggIdx, argIdx),
RexExpressionUtils.fromRexLiteral((RexLiteral) field));
- }
- }
- }
- }
- }
- return rexLiteralMap;
- }
-}
diff --git
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java
b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java
index cbac4de9e3..6c2498c70b 100644
---
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java
+++
b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java
@@ -117,11 +117,6 @@ public class PinotQueryRuleSets {
PruneEmptyRules.UNION_INSTANCE
);
- // Pinot specific rules to run using a single RuleCollection since we attach
aggregate info after optimizer.
- public static final Collection<RelOptRule> PINOT_AGG_PROCESS_RULES =
ImmutableList.of(
- PinotAggregateLiteralAttachmentRule.INSTANCE
- );
-
// Pinot specific rules that should be run AFTER all other rules
public static final Collection<RelOptRule> PINOT_POST_RULES =
ImmutableList.of(
// Evaluate the Literal filter nodes
diff --git
a/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java
b/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java
index 059faac2d4..9c53cdee6a 100644
---
a/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java
+++
b/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java
@@ -328,10 +328,6 @@ public class QueryEnvironment {
hepProgramBuilder.addRuleInstance(relOptRule);
}
- // ----
- // Run Pinot rule to attach aggregation auxiliary info
-
hepProgramBuilder.addRuleCollection(PinotQueryRuleSets.PINOT_AGG_PROCESS_RULES);
-
// ----
// Pushdown filters using a single HepInstruction.
hepProgramBuilder.addRuleCollection(PinotQueryRuleSets.FILTER_PUSHDOWN_RULES);
diff --git
a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java
b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java
index debe59d0ab..1862adf95e 100644
---
a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java
+++
b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java
@@ -231,7 +231,7 @@ public class CalciteRexExpressionParser {
}
break;
default:
- functionName = functionKind.name();
+ functionName = canonicalizeFunctionName(functionKind.name());
break;
}
List<RexExpression> childNodes = rexCall.getFunctionOperands();
@@ -288,7 +288,7 @@ public class CalciteRexExpressionParser {
private static Expression getFunctionExpression(String canonicalName) {
Expression expression = new Expression(ExpressionType.FUNCTION);
- Function function = new Function(canonicalizeFunctionName(canonicalName));
+ Function function = new Function(canonicalName);
expression.setFunctionCall(function);
return expression;
}
diff --git
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/LiteralHintUtils.java
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/LiteralHintUtils.java
deleted file mode 100644
index ea854e9aba..0000000000
---
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/LiteralHintUtils.java
+++ /dev/null
@@ -1,85 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.pinot.query.planner.logical;
-
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import org.apache.calcite.util.Pair;
-import org.apache.commons.lang3.StringUtils;
-import org.apache.pinot.common.request.Literal;
-import org.apache.pinot.spi.data.FieldSpec;
-import org.apache.pinot.spi.utils.BytesUtils;
-
-
-public class LiteralHintUtils {
- private LiteralHintUtils() {
- }
-
- public static String literalMapToHintString(Map<Pair<Integer, Integer>,
RexExpression.Literal> literals) {
- List<String> literalStrings = new ArrayList<>(literals.size());
- for (Map.Entry<Pair<Integer, Integer>, RexExpression.Literal> e :
literals.entrySet()) {
- // individual literal parts are joined with `|`
- literalStrings.add(
- String.format("%d|%d|%s|%s", e.getKey().left, e.getKey().right,
e.getValue().getDataType().name(),
- e.getValue().getValue()));
- }
- // semi-colon is used to separate between encoded literals
- return "{" + StringUtils.join(literalStrings, ";:;") + "}";
- }
-
- public static Map<Integer, Map<Integer, Literal>>
hintStringToLiteralMap(String literalString) {
- Map<Integer, Map<Integer, Literal>> aggCallToLiteralArgsMap = new
HashMap<>();
- if (StringUtils.isNotEmpty(literalString) && !"{}".equals(literalString)) {
- String[] literalStringArr = literalString.substring(1,
literalString.length() - 1).split(";:;");
- for (String literalStr : literalStringArr) {
- String[] literalStrParts = literalStr.split("\\|", 4);
- int aggIdx = Integer.parseInt(literalStrParts[0]);
- int argListIdx = Integer.parseInt(literalStrParts[1]);
- String dataTypeNameStr = literalStrParts[2];
- String valueStr = literalStrParts[3];
- Map<Integer, Literal> literalArgs =
aggCallToLiteralArgsMap.computeIfAbsent(aggIdx, i -> new HashMap<>());
- literalArgs.put(argListIdx, stringToLiteral(dataTypeNameStr,
valueStr));
- }
- }
- return aggCallToLiteralArgsMap;
- }
-
- private static Literal stringToLiteral(String dataTypeStr, String valueStr) {
- FieldSpec.DataType dataType = FieldSpec.DataType.valueOf(dataTypeStr);
- switch (dataType) {
- case BOOLEAN:
- return Literal.boolValue(valueStr.equals("1"));
- case INT:
- return Literal.intValue(Integer.parseInt(valueStr));
- case LONG:
- return Literal.longValue(Long.parseLong(valueStr));
- case FLOAT:
- case DOUBLE:
- return Literal.doubleValue(Double.parseDouble(valueStr));
- case STRING:
- return Literal.stringValue(valueStr);
- case BYTES:
- return Literal.binaryValue(BytesUtils.toBytes(valueStr));
- default:
- throw new UnsupportedOperationException("Unsupported RexLiteral type:
" + dataTypeStr);
- }
- }
-}
diff --git
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java
index 5a80cd2596..c2e9890358 100644
---
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java
+++
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java
@@ -246,8 +246,10 @@ public class RexExpressionUtils {
}
public static RexExpression fromAggregateCall(AggregateCall aggregateCall) {
- List<RexExpression> operands =
-
aggregateCall.getArgList().stream().map(RexExpression.InputRef::new).collect(Collectors.toList());
+ List<RexExpression> operands = new
ArrayList<>(aggregateCall.rexList.size());
+ for (RexNode rexNode : aggregateCall.rexList) {
+ operands.add(fromRexNode(rexNode));
+ }
return new
RexExpression.FunctionCall(aggregateCall.getAggregation().getKind(),
RelToPlanNodeConverter.convertToColumnDataType(aggregateCall.getType()),
aggregateCall.getAggregation().getName(), operands,
aggregateCall.isDistinct());
diff --git
a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
index 810202ca49..8e74660e7a 100644
---
a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
+++
b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
@@ -255,7 +255,8 @@ public class QueryCompilationTest extends
QueryEnvironmentTestBase {
public void testQueryWithHint() {
// Hinting the query to use final stage aggregation makes server directly
return final result
// This is useful when data is already partitioned by col1
- String query = "SELECT /*+ aggOptionsInternal(agg_type='DIRECT') */ col1,
COUNT(*) FROM b GROUP BY col1";
+ String query =
+ "SELECT /*+ aggOptions(is_partitioned_by_group_by_keys='true') */
col1, COUNT(*) FROM b GROUP BY col1";
DispatchableSubPlan dispatchableSubPlan =
_queryEnvironment.planQuery(query);
List<DispatchablePlanFragment> stagePlans =
dispatchableSubPlan.getQueryStageList();
int numStages = stagePlans.size();
diff --git a/pinot-query-planner/src/test/resources/queries/GroupByPlans.json
b/pinot-query-planner/src/test/resources/queries/GroupByPlans.json
index a7a4b1a8be..8a0878d6e1 100644
--- a/pinot-query-planner/src/test/resources/queries/GroupByPlans.json
+++ b/pinot-query-planner/src/test/resources/queries/GroupByPlans.json
@@ -102,7 +102,7 @@
"sql": "EXPLAIN PLAN FOR SELECT /*+
aggOptions(is_skip_leaf_stage_group_by='true') */ a.col1, SUM(a.col3) FROM a
GROUP BY a.col1",
"output": [
"Execution Plan",
- "\nLogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)])",
+ "\nLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
"\n PinotLogicalExchange(distribution=[hash[0]])",
"\n LogicalProject(col1=[$0], col3=[$2])",
"\n LogicalTableScan(table=[[default, a]])",
@@ -128,7 +128,7 @@
"output": [
"Execution Plan",
"\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[/(CAST($1):DOUBLE
NOT NULL, $2)], EXPR$3=[$3], EXPR$4=[$4])",
- "\n LogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)],
agg#1=[COUNT()], EXPR$3=[MAX($1)], EXPR$4=[MIN($1)])",
+ "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)],
agg#1=[COUNT()], agg#2=[MAX($1)], agg#3=[MIN($1)])",
"\n PinotLogicalExchange(distribution=[hash[0]])",
"\n LogicalProject(col1=[$0], col3=[$2])",
"\n LogicalTableScan(table=[[default, a]])",
@@ -140,7 +140,7 @@
"sql": "EXPLAIN PLAN FOR SELECT /*+
aggOptions(is_skip_leaf_stage_group_by='true') */ a.col1, SUM(a.col3) FROM a
WHERE a.col3 >= 0 AND a.col2 = 'a' GROUP BY a.col1",
"output": [
"Execution Plan",
- "\nLogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)])",
+ "\nLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
"\n PinotLogicalExchange(distribution=[hash[0]])",
"\n LogicalProject(col1=[$0], col3=[$2])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1,
_UTF-8'a'))])",
@@ -153,7 +153,7 @@
"sql": "EXPLAIN PLAN FOR SELECT /*+
aggOptions(is_skip_leaf_stage_group_by='true') */ a.col1, SUM(a.col3),
MAX(a.col3) FROM a WHERE a.col3 >= 0 AND a.col2 = 'a' GROUP BY a.col1",
"output": [
"Execution Plan",
- "\nLogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)],
EXPR$2=[MAX($1)])",
+ "\nLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)],
agg#1=[MAX($1)])",
"\n PinotLogicalExchange(distribution=[hash[0]])",
"\n LogicalProject(col1=[$0], col3=[$2])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1,
_UTF-8'a'))])",
@@ -167,7 +167,7 @@
"notes": "TODO: Needs follow up. Project should only keep a.col1 since
the other columns are pushed to the filter, but it currently keeps them all",
"output": [
"Execution Plan",
- "\nLogicalAggregate(group=[{0}], EXPR$1=[COUNT()])",
+ "\nLogicalAggregate(group=[{0}], agg#0=[COUNT()])",
"\n PinotLogicalExchange(distribution=[hash[0]])",
"\n LogicalProject(col1=[$0])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1,
_UTF-8'a'))])",
@@ -181,7 +181,7 @@
"output": [
"Execution Plan",
"\nLogicalProject(col2=[$1], col1=[$0], EXPR$2=[$2])",
- "\n LogicalAggregate(group=[{0, 1}], EXPR$2=[$SUM0($2)])",
+ "\n LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])",
"\n PinotLogicalExchange(distribution=[hash[0, 1]])",
"\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($0,
_UTF-8'a'))])",
@@ -196,7 +196,7 @@
"Execution Plan",
"\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[$2])",
"\n LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20),
<=($2, 10), =(/(CAST($2):DOUBLE NOT NULL, $1), 5))])",
- "\n LogicalAggregate(group=[{0}], EXPR$1=[COUNT()],
EXPR$2=[$SUM0($1)], agg#2=[MAX($1)], agg#3=[MIN($1)])",
+ "\n LogicalAggregate(group=[{0}], agg#0=[COUNT()],
agg#1=[$SUM0($1)], agg#2=[MAX($1)], agg#3=[MIN($1)])",
"\n PinotLogicalExchange(distribution=[hash[0]])",
"\n LogicalProject(col1=[$0], col3=[$2])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1,
_UTF-8'a'))])",
@@ -211,7 +211,7 @@
"Execution Plan",
"\nLogicalProject(col1=[$0], EXPR$1=[$1])",
"\n LogicalFilter(condition=[AND(>=($2, 0), <($3, 20), <=($1, 10),
=(/(CAST($1):DOUBLE NOT NULL, $4), 5))])",
- "\n LogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)],
agg#1=[MAX($1)], agg#2=[MIN($1)], agg#3=[COUNT()])",
+ "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)],
agg#1=[MAX($1)], agg#2=[MIN($1)], agg#3=[COUNT()])",
"\n PinotLogicalExchange(distribution=[hash[0]])",
"\n LogicalProject(col1=[$0], col3=[$2])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1,
_UTF-8'a'))])",
@@ -226,7 +226,7 @@
"Execution Plan",
"\nLogicalProject(value1=[$0], count=[$1], SUM=[$2])",
"\n LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20),
<=($2, 10), =(/(CAST($2):DOUBLE NOT NULL, $1), 5))])",
- "\n LogicalAggregate(group=[{0}], count=[COUNT()],
SUM=[$SUM0($1)], agg#2=[MAX($1)], agg#3=[MIN($1)])",
+ "\n LogicalAggregate(group=[{0}], agg#0=[COUNT()],
agg#1=[$SUM0($1)], agg#2=[MAX($1)], agg#3=[MIN($1)])",
"\n PinotLogicalExchange(distribution=[hash[0]])",
"\n LogicalProject(col1=[$0], col3=[$2])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1,
_UTF-8'a'))])",
diff --git a/pinot-query-planner/src/test/resources/queries/OrderByPlans.json
b/pinot-query-planner/src/test/resources/queries/OrderByPlans.json
index 7b97f583ea..32d1eb65f8 100644
--- a/pinot-query-planner/src/test/resources/queries/OrderByPlans.json
+++ b/pinot-query-planner/src/test/resources/queries/OrderByPlans.json
@@ -93,7 +93,7 @@
"Execution Plan",
"\nLogicalSort(sort0=[$0], dir0=[ASC])",
"\n PinotLogicalSortExchange(distribution=[hash], collation=[[0]],
isSortOnSender=[false], isSortOnReceiver=[true])",
- "\n LogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)])",
+ "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
"\n PinotLogicalExchange(distribution=[hash[0]])",
"\n LogicalProject(col1=[$0], col3=[$2])",
"\n LogicalTableScan(table=[[default, a]])",
@@ -121,7 +121,7 @@
"Execution Plan",
"\nLogicalSort(sort0=[$0], dir0=[ASC])",
"\n PinotLogicalSortExchange(distribution=[hash], collation=[[0]],
isSortOnSender=[false], isSortOnReceiver=[true])",
- "\n LogicalAggregate(group=[{0}], sum=[$SUM0($1)])",
+ "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
"\n PinotLogicalExchange(distribution=[hash[0]])",
"\n LogicalProject(col1=[$0], col3=[$2])",
"\n LogicalTableScan(table=[[default, a]])",
diff --git
a/pinot-query-planner/src/test/resources/queries/PinotHintablePlans.json
b/pinot-query-planner/src/test/resources/queries/PinotHintablePlans.json
index 5841c442ff..3f0a4cd0f0 100644
--- a/pinot-query-planner/src/test/resources/queries/PinotHintablePlans.json
+++ b/pinot-query-planner/src/test/resources/queries/PinotHintablePlans.json
@@ -100,10 +100,10 @@
},
{
"description": "semi-join with dynamic_broadcast join strategy then
group-by on same key",
- "sql": "EXPLAIN PLAN FOR SELECT /*+
aggOptionsInternal(agg_type='DIRECT') */ a.col1, SUM(a.col3) FROM a WHERE
a.col1 IN (SELECT col2 FROM b WHERE b.col3 > 0) GROUP BY 1",
+ "sql": "EXPLAIN PLAN FOR SELECT /*+
aggOptions(is_partitioned_by_group_by_keys='true') */ a.col1, SUM(a.col3) FROM
a WHERE a.col1 IN (SELECT col2 FROM b WHERE b.col3 > 0) GROUP BY 1",
"output": [
"Execution Plan",
- "\nLogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)])",
+ "\nLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
"\n LogicalJoin(condition=[=($0, $2)], joinType=[semi])",
"\n LogicalProject(col1=[$0], col3=[$2])",
"\n LogicalTableScan(table=[[default, a]])",
@@ -138,7 +138,7 @@
"output": [
"Execution Plan",
"\nLogicalProject(col2=[$1], col1=[$0], EXPR$2=[$2])",
- "\n LogicalAggregate(group=[{0, 1}], EXPR$2=[$SUM0($2)])",
+ "\n LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])",
"\n PinotLogicalExchange(distribution=[hash[0, 1]])",
"\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($0,
_UTF-8'a'))])",
@@ -153,7 +153,7 @@
"Execution Plan",
"\nLogicalProject(col2=[$0], EXPR$1=[$1], EXPR$2=[$2], EXPR$3=[$3])",
"\n LogicalFilter(condition=[AND(>($1, 10), >=($4, 0), <($5, 20),
<=($2, 10), =(/(CAST($2):DOUBLE NOT NULL, $1), 5))])",
- "\n LogicalAggregate(group=[{0}], EXPR$1=[COUNT()],
EXPR$2=[$SUM0($1)], EXPR$3=[$SUM0($2)], agg#3=[MAX($1)], agg#4=[MIN($1)])",
+ "\n LogicalAggregate(group=[{0}], agg#0=[COUNT()],
agg#1=[$SUM0($1)], agg#2=[$SUM0($2)], agg#3=[MAX($1)], agg#4=[MIN($1)])",
"\n PinotLogicalExchange(distribution=[hash[0]])",
"\n LogicalProject(col2=[$1], col3=[$2],
$f2=[CAST($0):DECIMAL(1000, 500) NOT NULL])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1,
_UTF-8'a'))])",
@@ -162,24 +162,11 @@
]
},
{
- "description": "aggregate with skip intermediate stage hint (via
hinting the leaf stage group by as final stage_",
- "sql": "EXPLAIN PLAN FOR SELECT /*+
aggOptionsInternal(agg_type='DIRECT') */ a.col2, COUNT(*), SUM(a.col3),
SUM(a.col1) FROM a WHERE a.col3 >= 0 AND a.col2 = 'a' GROUP BY a.col2 HAVING
COUNT(*) > 10",
- "output": [
- "Execution Plan",
- "\nLogicalFilter(condition=[>($1, 10)])",
- "\n LogicalAggregate(group=[{0}], EXPR$1=[COUNT()],
EXPR$2=[$SUM0($1)], EXPR$3=[$SUM0($2)])",
- "\n LogicalProject(col2=[$1], col3=[$2],
$f2=[CAST($0):DECIMAL(1000, 500) NOT NULL])",
- "\n LogicalFilter(condition=[AND(>=($2, 0), =($1,
_UTF-8'a'))])",
- "\n LogicalTableScan(table=[[default, a]])",
- "\n"
- ]
- },
- {
- "description": "aggregate with skip leaf stage hint (via hint option
is_partitioned_by_group_by_keys",
+ "description": "aggregate with skip intermediate stage hint (via hint
option is_partitioned_by_group_by_keys)",
"sql": "EXPLAIN PLAN FOR SELECT /*+
aggOptions(is_partitioned_by_group_by_keys='true') */ a.col2, COUNT(*),
SUM(a.col3), SUM(a.col1) FROM a WHERE a.col3 >= 0 AND a.col2 = 'a' GROUP BY
a.col2",
"output": [
"Execution Plan",
- "\nLogicalAggregate(group=[{0}], EXPR$1=[COUNT()],
EXPR$2=[$SUM0($1)], EXPR$3=[$SUM0($2)])",
+ "\nLogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($1)],
agg#2=[$SUM0($2)])",
"\n LogicalProject(col2=[$1], col3=[$2],
$f2=[CAST($0):DECIMAL(1000, 500) NOT NULL])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])",
"\n LogicalTableScan(table=[[default, a]])",
@@ -409,7 +396,7 @@
"sql": "EXPLAIN PLAN FOR SELECT /*+
aggOptions(is_partitioned_by_group_by_keys='true') */ a.col2, SUM(a.col3) FROM
a /*+ tableOptions(partition_function='hashcode', partition_key='col2',
partition_size='4') */ WHERE a.col2 IN (SELECT col1 FROM b /*+
tableOptions(partition_function='hashcode', partition_key='col1',
partition_size='4') */ WHERE b.col3 > 0) GROUP BY 1",
"output": [
"Execution Plan",
- "\nLogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)])",
+ "\nLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
"\n LogicalJoin(condition=[=($0, $2)], joinType=[semi])",
"\n LogicalProject(col2=[$1], col3=[$2])",
"\n LogicalTableScan(table=[[default, a]])",
@@ -425,7 +412,7 @@
"sql": "EXPLAIN PLAN FOR SELECT /*+
aggOptions(is_partitioned_by_group_by_keys='true') */ a.col2, SUM(a.col3) FROM
a /*+ tableOptions(partition_function='hashcode', partition_key='col2',
partition_size='4') */ WHERE a.col2 IN (SELECT col1 FROM b WHERE b.col3 > 0)
GROUP BY 1",
"output": [
"Execution Plan",
- "\nLogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)])",
+ "\nLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
"\n LogicalJoin(condition=[=($0, $2)], joinType=[semi])",
"\n LogicalProject(col2=[$1], col3=[$2])",
"\n LogicalTableScan(table=[[default, a]])",
@@ -443,7 +430,7 @@
"Execution Plan",
"\nLogicalProject(col2=[$0], EXPR$1=[$1])",
"\n LogicalFilter(condition=[>($2, 5)])",
- "\n LogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)],
agg#1=[COUNT()])",
+ "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)],
agg#1=[COUNT()])",
"\n LogicalJoin(condition=[=($0, $2)], joinType=[semi])",
"\n LogicalProject(col2=[$1], col3=[$2])",
"\n LogicalTableScan(table=[[default, a]])",
@@ -461,7 +448,7 @@
"Execution Plan",
"\nLogicalSort(sort0=[$1], dir0=[DESC])",
"\n PinotLogicalSortExchange(distribution=[hash], collation=[[1
DESC]], isSortOnSender=[false], isSortOnReceiver=[true])",
- "\n LogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)])",
+ "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
"\n LogicalJoin(condition=[=($0, $2)], joinType=[semi])",
"\n LogicalProject(col2=[$1], col3=[$2])",
"\n LogicalTableScan(table=[[default, a]])",
diff --git
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
index 7cf7d5f2a7..a19ff64d4e 100644
---
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
+++
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
@@ -27,10 +27,8 @@ import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import org.apache.calcite.sql.SqlKind;
-import org.apache.pinot.calcite.rel.hint.PinotHintOptions;
import org.apache.pinot.common.datablock.DataBlock;
import org.apache.pinot.common.datatable.StatMap;
-import org.apache.pinot.common.request.Literal;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.request.context.FunctionContext;
import org.apache.pinot.common.utils.DataSchema;
@@ -43,13 +41,13 @@ import
org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import
org.apache.pinot.core.query.aggregation.function.AggregationFunctionFactory;
import
org.apache.pinot.core.query.aggregation.function.CountAggregationFunction;
import org.apache.pinot.core.util.DataBlockExtractUtils;
-import org.apache.pinot.query.planner.logical.LiteralHintUtils;
import org.apache.pinot.query.planner.logical.RexExpression;
import org.apache.pinot.query.planner.plannode.AbstractPlanNode;
import org.apache.pinot.query.planner.plannode.AggregateNode.AggType;
import org.apache.pinot.query.runtime.blocks.TransferableBlock;
import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
-import org.apache.pinot.segment.spi.AggregationFunctionType;
+import org.apache.pinot.spi.data.FieldSpec.DataType;
+import org.apache.pinot.spi.utils.BooleanUtils;
import org.roaringbitmap.RoaringBitmap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -65,11 +63,9 @@ public class AggregateOperator extends MultiStageOperator {
private static final String EXPLAIN_NAME = "AGGREGATE_OPERATOR";
private static final CountAggregationFunction COUNT_STAR_AGG_FUNCTION =
new
CountAggregationFunction(Collections.singletonList(ExpressionContext.forIdentifier("*")),
false);
- private static final ExpressionContext PLACEHOLDER_IDENTIFIER =
ExpressionContext.forIdentifier("__PLACEHOLDER__");
private final MultiStageOperator _inputOperator;
private final DataSchema _resultSchema;
- private final AggType _aggType;
private final MultistageAggregationExecutor _aggregationExecutor;
private final MultistageGroupByExecutor _groupByExecutor;
@Nullable
@@ -78,29 +74,15 @@ public class AggregateOperator extends MultiStageOperator {
private boolean _hasConstructedAggregateBlock;
- public AggregateOperator(OpChainExecutionContext context, MultiStageOperator
inputOperator,
- DataSchema resultSchema, List<RexExpression> aggCalls,
List<RexExpression> groupSet, AggType aggType,
- List<Integer> filterArgIndices, @Nullable AbstractPlanNode.NodeHint
nodeHint) {
+ public AggregateOperator(OpChainExecutionContext context, MultiStageOperator
inputOperator, DataSchema resultSchema,
+ List<RexExpression> aggCalls, List<RexExpression> groupSet, AggType
aggType, List<Integer> filterArgIndices,
+ @Nullable AbstractPlanNode.NodeHint nodeHint) {
super(context);
_inputOperator = inputOperator;
_resultSchema = resultSchema;
- _aggType = aggType;
-
- // Process literal hints
- Map<Integer, Map<Integer, Literal>> literalArgumentsMap = null;
- if (nodeHint != null) {
- Map<String, String> aggOptions =
nodeHint._hintOptions.get(PinotHintOptions.INTERNAL_AGG_OPTIONS);
- if (aggOptions != null) {
- literalArgumentsMap = LiteralHintUtils.hintStringToLiteralMap(
-
aggOptions.get(PinotHintOptions.InternalAggregateOptions.AGG_CALL_SIGNATURE));
- }
- }
- if (literalArgumentsMap == null) {
- literalArgumentsMap = Collections.emptyMap();
- }
// Initialize the aggregation functions
- AggregationFunction<?, ?>[] aggFunctions = getAggFunctions(aggCalls,
literalArgumentsMap);
+ AggregationFunction<?, ?>[] aggFunctions = getAggFunctions(aggCalls);
// Process the filter argument indices
int numFunctions = aggFunctions.length;
@@ -214,27 +196,16 @@ public class AggregateOperator extends MultiStageOperator
{
return block;
}
- private AggregationFunction<?, ?>[] getAggFunctions(List<RexExpression>
aggCalls,
- Map<Integer, Map<Integer, Literal>> literalArgumentsMap) {
+ private AggregationFunction<?, ?>[] getAggFunctions(List<RexExpression>
aggCalls) {
int numFunctions = aggCalls.size();
AggregationFunction<?, ?>[] aggFunctions = new
AggregationFunction[numFunctions];
- if (!_aggType.isInputIntermediateFormat()) {
- for (int i = 0; i < numFunctions; i++) {
- Map<Integer, Literal> literalArguments =
literalArgumentsMap.getOrDefault(i, Collections.emptyMap());
- aggFunctions[i] =
getAggFunctionForRawInput((RexExpression.FunctionCall) aggCalls.get(i),
literalArguments);
- }
- } else {
- for (int i = 0; i < numFunctions; i++) {
- Map<Integer, Literal> literalArguments =
literalArgumentsMap.getOrDefault(i, Collections.emptyMap());
- aggFunctions[i] =
- getAggFunctionForIntermediateInput((RexExpression.FunctionCall)
aggCalls.get(i), literalArguments);
- }
+ for (int i = 0; i < numFunctions; i++) {
+ aggFunctions[i] = getAggFunction((RexExpression.FunctionCall)
aggCalls.get(i));
}
return aggFunctions;
}
- private AggregationFunction<?, ?>
getAggFunctionForRawInput(RexExpression.FunctionCall functionCall,
- Map<Integer, Literal> literalArguments) {
+ private AggregationFunction<?, ?> getAggFunction(RexExpression.FunctionCall
functionCall) {
String functionName = functionCall.getFunctionName();
List<RexExpression> operands = functionCall.getFunctionOperands();
int numArguments = operands.size();
@@ -244,78 +215,26 @@ public class AggregateOperator extends MultiStageOperator {
return COUNT_STAR_AGG_FUNCTION;
}
List<ExpressionContext> arguments = new ArrayList<>(numArguments);
- for (int i = 0; i < numArguments; i++) {
- Literal literalArgument = literalArguments.get(i);
- if (literalArgument != null) {
- arguments.add(ExpressionContext.forLiteralContext(literalArgument));
+ for (RexExpression operand : operands) {
+ if (operand instanceof RexExpression.InputRef) {
+ RexExpression.InputRef inputRef = (RexExpression.InputRef) operand;
+ arguments.add(ExpressionContext.forIdentifier(fromColIdToIdentifier(inputRef.getIndex())));
} else {
- RexExpression operand = operands.get(i);
- switch (operand.getKind()) {
- case INPUT_REF:
- RexExpression.InputRef inputRef = (RexExpression.InputRef) operand;
- arguments.add(ExpressionContext.forIdentifier(fromColIdToIdentifier(inputRef.getIndex())));
- break;
- case LITERAL:
- RexExpression.Literal literalRexExp = (RexExpression.Literal) operand;
- arguments.add(ExpressionContext.forLiteralContext(literalRexExp.getDataType().toDataType(),
- literalRexExp.getValue()));
- break;
- default:
- throw new IllegalStateException("Illegal aggregation function operand type: " + operand.getKind());
+ assert operand instanceof RexExpression.Literal;
+ RexExpression.Literal literal = (RexExpression.Literal) operand;
+ DataType dataType = literal.getDataType().toDataType();
+ Object value = literal.getValue();
+ // TODO: Fix BOOLEAN literal to directly store true/false
+ if (dataType == DataType.BOOLEAN) {
+ value = BooleanUtils.fromNonNullInternalValue(value);
}
+ arguments.add(ExpressionContext.forLiteralContext(dataType, value));
}
}
- handleListAggDistinctArg(functionName, functionCall, arguments);
return AggregationFunctionFactory.getAggregationFunction(
new FunctionContext(FunctionContext.Type.AGGREGATION, functionName, arguments), true);
}
- private static AggregationFunction<?, ?> getAggFunctionForIntermediateInput(RexExpression.FunctionCall functionCall,
- Map<Integer, Literal> literalArguments) {
- String functionName = functionCall.getFunctionName();
- List<RexExpression> operands = functionCall.getFunctionOperands();
- int numArguments = operands.size();
- Preconditions.checkState(numArguments == 1, "Intermediate aggregate must have 1 argument, got: %s", numArguments);
- RexExpression operand = operands.get(0);
- Preconditions.checkState(operand.getKind() == SqlKind.INPUT_REF,
- "Intermediate aggregate argument must be an input reference, got: %s", operand.getKind());
- // We might need to append extra arguments extracted from the hint to match the signature of the aggregation
- Literal numArgumentsLiteral = literalArguments.get(-1);
- if (numArgumentsLiteral == null) {
- return AggregationFunctionFactory.getAggregationFunction(
- new FunctionContext(FunctionContext.Type.AGGREGATION, functionName, Collections.singletonList(
- ExpressionContext.forIdentifier(fromColIdToIdentifier(((RexExpression.InputRef) operand).getIndex())))),
- true);
- } else {
- int numExpectedArguments = numArgumentsLiteral.getIntValue();
- List<ExpressionContext> arguments = new ArrayList<>(numExpectedArguments);
- arguments.add(
- ExpressionContext.forIdentifier(fromColIdToIdentifier(((RexExpression.InputRef) operand).getIndex())));
- for (int i = 1; i < numExpectedArguments; i++) {
- Literal literalArgument = literalArguments.get(i);
- if (literalArgument != null) {
- arguments.add(ExpressionContext.forLiteralContext(literalArgument));
- } else {
- arguments.add(PLACEHOLDER_IDENTIFIER);
- }
- }
- handleListAggDistinctArg(functionName, functionCall, arguments);
- return AggregationFunctionFactory.getAggregationFunction(
- new FunctionContext(FunctionContext.Type.AGGREGATION, functionName, arguments), true);
- }
- }
-
- private static void handleListAggDistinctArg(String functionName, RexExpression.FunctionCall functionCall,
- List<ExpressionContext> arguments) {
- String upperCaseFunctionName =
- AggregationFunctionType.getNormalizedAggregationFunctionName(functionName);
- if (upperCaseFunctionName.equals("LISTAGG")) {
- if (functionCall.isDistinct()) {
- arguments.add(ExpressionContext.forLiteralContext(Literal.boolValue(true)));
- }
- }
- }
-
private static String fromColIdToIdentifier(int colId) {
return "$" + colId;
}
diff --git a/pinot-query-runtime/src/test/resources/queries/QueryHints.json b/pinot-query-runtime/src/test/resources/queries/QueryHints.json
index f8c850fcd3..81a939c2e1 100644
--- a/pinot-query-runtime/src/test/resources/queries/QueryHints.json
+++ b/pinot-query-runtime/src/test/resources/queries/QueryHints.json
@@ -275,7 +275,7 @@
},
{
"description": "semi-join with dynamic_broadcast join strategy then
group-by on same key",
- "sql": "SELECT /*+ aggOptionsInternal(agg_type='DIRECT') */
{tbl1}.num, SUM({tbl1}.val) FROM {tbl1} WHERE {tbl1}.name IN (SELECT id FROM
{tbl2} WHERE {tbl2}.data > 0) GROUP BY {tbl1}.num"
+ "sql": "SELECT /*+ aggOptions(is_partitioned_by_group_by_keys='true')
*/ {tbl1}.num, SUM({tbl1}.val) FROM {tbl1} WHERE {tbl1}.name IN (SELECT id FROM
{tbl2} WHERE {tbl2}.data > 0) GROUP BY {tbl1}.num"
},
{
"description": "semi-join with dynamic_broadcast join strategy then
group-by on different key",
@@ -290,11 +290,7 @@
"sql": "SELECT /*+ aggOptions(is_skip_leaf_stage_group_by='true') */
{tbl1}.num, COUNT(*), SUM({tbl1}.val), SUM({tbl1}.num) FROM {tbl1} WHERE
{tbl1}.val >= 0 AND {tbl1}.name != 'a' GROUP BY {tbl1}.num HAVING COUNT(*) > 10
AND MAX({tbl1}.val) >= 0 AND MIN({tbl1}.val) < 20 AND SUM({tbl1}.val) <= 10 AND
AVG({tbl1}.val) = 5"
},
{
- "description": "aggregate with skip intermediate stage hint (via
hinting the leaf stage group by as final stage_",
- "sql": "SELECT /*+ aggOptionsInternal(agg_type='DIRECT') */
{tbl1}.num, COUNT(*), SUM({tbl1}.val), SUM({tbl1}.num) FROM {tbl1} WHERE
{tbl1}.val >= 0 AND {tbl1}.name != 'a' GROUP BY {tbl1}.num HAVING COUNT(*) > 10"
- },
- {
- "description": "aggregate with skip leaf stage hint (via hint option
is_partitioned_by_group_by_keys",
+ "description": "aggregate with skip intermediate stage hint (via hint
option is_partitioned_by_group_by_keys)",
"sql": "SELECT /*+ aggOptions(is_partitioned_by_group_by_keys='true')
*/ {tbl1}.num, COUNT(*), SUM({tbl1}.val), SUM({tbl1}.num) FROM {tbl1} WHERE
{tbl1}.val >= 0 AND {tbl1}.name != 'a' GROUP BY {tbl1}.num"
},
{
diff --git a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
index a6c468d8fe..877ac7f232 100644
--- a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
+++ b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
@@ -460,9 +460,10 @@ public enum AggregationFunctionType {
* <p>NOTE: Underscores in the function name are ignored.
*/
public static AggregationFunctionType getAggregationFunctionType(String functionName) {
- if (functionName.regionMatches(true, 0, "percentile", 0, 10)) {
+ String normalizedFunctionName = getNormalizedAggregationFunctionName(functionName);
+ if (normalizedFunctionName.regionMatches(false, 0, "PERCENTILE", 0, 10)) {
// This style of aggregation functions is not supported in the multistage engine
- String remainingFunctionName = getNormalizedAggregationFunctionName(functionName).substring(10).toUpperCase();
+ String remainingFunctionName = normalizedFunctionName.substring(10).toUpperCase();
if (remainingFunctionName.isEmpty() || remainingFunctionName.matches("\\d+")) {
return PERCENTILE;
} else if (remainingFunctionName.equals("EST") || remainingFunctionName.matches("EST\\d+")) {
@@ -496,7 +497,7 @@ public enum AggregationFunctionType {
}
} else {
try {
- return AggregationFunctionType.valueOf(getNormalizedAggregationFunctionName(functionName));
+ return AggregationFunctionType.valueOf(normalizedFunctionName);
} catch (IllegalArgumentException e) {
throw new IllegalArgumentException("Invalid aggregation function name: " + functionName);
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]