This is an automated email from the ASF dual-hosted git repository.
jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 303b1a7cbe [Issue 7519] Adds support for multiple filtered/unfiltered
aggregations with GROUP BY (#10000)
303b1a7cbe is described below
commit 303b1a7cbe78244491f0580eb88e966a41b56b25
Author: Evan Galpin <[email protected]>
AuthorDate: Wed Jan 4 19:01:40 2023 -0800
[Issue 7519] Adds support for multiple filtered/unfiltered aggregations
with GROUP BY (#10000)
---
.../operator/query/FilteredGroupByOperator.java | 215 +++++++++++++++++++++
.../pinot/core/plan/AggregationPlanNode.java | 87 +--------
.../apache/pinot/core/plan/GroupByPlanNode.java | 30 ++-
.../function/AggregationFunctionUtils.java | 94 +++++++++
.../groupby/DefaultGroupByExecutor.java | 56 ++++--
.../query/aggregation/groupby/GroupByExecutor.java | 4 +
.../core/query/request/context/QueryContext.java | 6 +-
.../query/aggregation/groupby/GroupByTrimTest.java | 9 +-
.../pinot/queries/FilteredAggregationsTest.java | 57 +++++-
9 files changed, 445 insertions(+), 113 deletions(-)
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java
b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java
new file mode 100644
index 0000000000..e895d817dd
--- /dev/null
+++
b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java
@@ -0,0 +1,215 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.operator.query;
+
+import java.util.Collection;
+import java.util.IdentityHashMap;
+import java.util.List;
+import java.util.stream.Collectors;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.core.common.Operator;
+import org.apache.pinot.core.data.table.IntermediateRecord;
+import org.apache.pinot.core.data.table.TableResizer;
+import org.apache.pinot.core.operator.BaseOperator;
+import org.apache.pinot.core.operator.ExecutionStatistics;
+import org.apache.pinot.core.operator.blocks.TransformBlock;
+import org.apache.pinot.core.operator.blocks.results.GroupByResultsBlock;
+import org.apache.pinot.core.operator.transform.TransformOperator;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
+import
org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult;
+import org.apache.pinot.core.query.aggregation.groupby.DefaultGroupByExecutor;
+import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
+import org.apache.pinot.core.query.aggregation.groupby.GroupKeyGenerator;
+import org.apache.pinot.core.query.request.context.QueryContext;
+import org.apache.pinot.core.util.GroupByUtils;
+import org.apache.pinot.spi.trace.Tracing;
+
+
+/**
+ * The <code>FilteredGroupByOperator</code> class provides the operator for group-by query on a single segment when
+ * there are 1 or more filter expressions on aggregations.
+ *
+ * <p>The aggregation functions arrive partitioned into (functions, transform operator) pairs. Each pair is processed
+ * by its own {@link DefaultGroupByExecutor}, and all executors share a single {@link GroupKeyGenerator} so that the
+ * per-pair results can be aligned.
+ */
+@SuppressWarnings("rawtypes")
+public class FilteredGroupByOperator extends BaseOperator<GroupByResultsBlock> {
+  private static final String EXPLAIN_NAME = "GROUP_BY_FILTERED";
+
+  // All aggregation functions; their order defines the order of the aggregation columns in the data schema
+  private final AggregationFunction[] _aggregationFunctions;
+  // Aggregation functions grouped by the transform operator that feeds them
+  private final List<Pair<AggregationFunction[], TransformOperator>> _aggFunctionsWithTransformOperator;
+  private final ExpressionContext[] _groupByExpressions;
+  private final long _numTotalDocs;
+  // Execution statistics, accumulated over all transform operators in getNextBlock()
+  private long _numDocsScanned;
+  private long _numEntriesScannedInFilter;
+  private long _numEntriesScannedPostFilter;
+  private final DataSchema _dataSchema;
+  private final QueryContext _queryContext;
+
+  /**
+   * Constructor for the class.
+   *
+   * @param aggregationFunctions All aggregation functions, in result column order
+   * @param aggFunctionsWithTransformOperator Pairs of aggregation functions and the transform operator feeding them;
+   *                                          must contain at least one pair when there are group-by expressions
+   * @param groupByExpressions Array of group-by expressions
+   * @param numTotalDocs Number of total docs in the segment
+   * @param queryContext Query context
+   */
+  public FilteredGroupByOperator(AggregationFunction[] aggregationFunctions,
+      List<Pair<AggregationFunction[], TransformOperator>> aggFunctionsWithTransformOperator,
+      ExpressionContext[] groupByExpressions, long numTotalDocs, QueryContext queryContext) {
+    _aggregationFunctions = aggregationFunctions;
+    _aggFunctionsWithTransformOperator = aggFunctionsWithTransformOperator;
+    _groupByExpressions = groupByExpressions;
+    _numTotalDocs = numTotalDocs;
+    _queryContext = queryContext;
+
+    // NOTE: The indexedTable expects that the data schema will have group by columns before aggregation columns
+    int numGroupByExpressions = groupByExpressions.length;
+    int numAggregationFunctions = aggregationFunctions.length;
+    int numColumns = numGroupByExpressions + numAggregationFunctions;
+    String[] columnNames = new String[numColumns];
+    DataSchema.ColumnDataType[] columnDataTypes = new DataSchema.ColumnDataType[numColumns];
+
+    // Extract column names and data types for group-by columns.
+    // The pair list is sized by the number of distinct filters, NOT by the number of group-by expressions, so it
+    // must not be indexed with the group-by expression index (that fails with IndexOutOfBoundsException as soon as
+    // there are more group-by expressions than pairs). Every transform operator is built over the same group-by
+    // expressions, so the first one is used to resolve the result metadata of all of them.
+    for (int i = 0; i < numGroupByExpressions; i++) {
+      ExpressionContext groupByExpression = groupByExpressions[i];
+      columnNames[i] = groupByExpression.toString();
+      columnDataTypes[i] = DataSchema.ColumnDataType.fromDataTypeSV(
+          aggFunctionsWithTransformOperator.get(0).getRight().getResultMetadata(groupByExpression).getDataType());
+    }
+
+    // Extract column names and data types for aggregation functions
+    for (int i = 0; i < numAggregationFunctions; i++) {
+      AggregationFunction aggregationFunction = aggregationFunctions[i];
+      int index = numGroupByExpressions + i;
+      columnNames[index] = aggregationFunction.getResultColumnName();
+      columnDataTypes[index] = aggregationFunction.getIntermediateResultColumnType();
+    }
+
+    _dataSchema = new DataSchema(columnNames, columnDataTypes);
+  }
+
+  /**
+   * Runs every (aggregation functions, transform operator) pair through a group-by executor, merges the per-pair
+   * result holders back into the order of the aggregation functions, and returns the results block (trimmed when the
+   * query has ORDER BY, segment group trim is enabled and there are more groups than the trim size).
+   */
+  @Override
+  protected GroupByResultsBlock getNextBlock() {
+    // TODO(egalpin): Support Startree query resolution when possible, even with FILTER expressions
+    int numAggregations = _aggregationFunctions.length;
+
+    GroupByResultHolder[] groupByResultHolders = new GroupByResultHolder[numAggregations];
+    // Maps each aggregation function (by identity) to its position in _aggregationFunctions, so that the result
+    // holders produced per pair can be placed at the position of their function
+    IdentityHashMap<AggregationFunction, Integer> resultHolderIndexMap = new IdentityHashMap<>(numAggregations);
+    for (int i = 0; i < numAggregations; i++) {
+      resultHolderIndexMap.put(_aggregationFunctions[i], i);
+    }
+
+    GroupKeyGenerator groupKeyGenerator = null;
+    for (Pair<AggregationFunction[], TransformOperator> filteredAggregation : _aggFunctionsWithTransformOperator) {
+      TransformOperator transformOperator = filteredAggregation.getRight();
+      AggregationFunction[] filteredAggFunctions = filteredAggregation.getLeft();
+
+      // Perform aggregation group-by on all the blocks
+      DefaultGroupByExecutor groupByExecutor;
+      if (groupKeyGenerator == null) {
+        // The group key generator should be shared across all AggregationFunctions so that agg results can be
+        // aligned. Given that filtered aggregations are stored as an iterable of iterables so that all filtered aggs
+        // with the same filter can share transform blocks, rather than a singular flat iterable in the case where
+        // aggs are all non-filtered, sharing a GroupKeyGenerator across all aggs cannot be accomplished by allowing
+        // the GroupByExecutor to have sole ownership of the GroupKeyGenerator. Therefore, we allow constructing a
+        // GroupByExecutor with a pre-existing GroupKeyGenerator so that the GroupKeyGenerator can be shared across
+        // loop iterations i.e. across all aggs.
+        groupByExecutor =
+            new DefaultGroupByExecutor(_queryContext, filteredAggFunctions, _groupByExpressions, transformOperator);
+        groupKeyGenerator = groupByExecutor.getGroupKeyGenerator();
+      } else {
+        groupByExecutor =
+            new DefaultGroupByExecutor(_queryContext, filteredAggFunctions, _groupByExpressions, transformOperator,
+                groupKeyGenerator);
+      }
+
+      int numDocsScanned = 0;
+      TransformBlock transformBlock;
+      while ((transformBlock = transformOperator.nextBlock()) != null) {
+        numDocsScanned += transformBlock.getNumDocs();
+        groupByExecutor.process(transformBlock);
+      }
+
+      _numDocsScanned += numDocsScanned;
+      _numEntriesScannedInFilter += transformOperator.getExecutionStatistics().getNumEntriesScannedInFilter();
+      _numEntriesScannedPostFilter += (long) numDocsScanned * transformOperator.getNumColumnsProjected();
+      GroupByResultHolder[] filterGroupByResults = groupByExecutor.getGroupByResultHolders();
+      for (int i = 0; i < filteredAggFunctions.length; i++) {
+        groupByResultHolders[resultHolderIndexMap.get(filteredAggFunctions[i])] = filterGroupByResults[i];
+      }
+    }
+    assert groupKeyGenerator != null;
+    // Bring every holder up to the final number of keys of the shared generator (presumably because holders filled
+    // by an earlier executor were sized before later executors added keys -- TODO confirm)
+    for (GroupByResultHolder groupByResultHolder : groupByResultHolders) {
+      groupByResultHolder.ensureCapacity(groupKeyGenerator.getNumKeys());
+    }
+
+    // Check if the groups limit is reached
+    boolean numGroupsLimitReached = groupKeyGenerator.getNumKeys() >= _queryContext.getNumGroupsLimit();
+    Tracing.activeRecording().setNumGroups(_queryContext.getNumGroupsLimit(), groupKeyGenerator.getNumKeys());
+
+    // Trim the groups iff:
+    // - Query has ORDER BY clause
+    // - Segment group trim is enabled
+    // - There are more groups than the trim size
+    // TODO: Currently the groups are not trimmed if there is no ordering specified. Consider ordering on group-by
+    //       columns if no ordering is specified.
+    int minGroupTrimSize = _queryContext.getMinSegmentGroupTrimSize();
+    if (_queryContext.getOrderByExpressions() != null && minGroupTrimSize > 0) {
+      int trimSize = GroupByUtils.getTableCapacity(_queryContext.getLimit(), minGroupTrimSize);
+      if (groupKeyGenerator.getNumKeys() > trimSize) {
+        TableResizer tableResizer = new TableResizer(_dataSchema, _queryContext);
+        Collection<IntermediateRecord> intermediateRecords =
+            tableResizer.trimInSegmentResults(groupKeyGenerator, groupByResultHolders, trimSize);
+        GroupByResultsBlock resultsBlock = new GroupByResultsBlock(_dataSchema, intermediateRecords);
+        resultsBlock.setNumGroupsLimitReached(numGroupsLimitReached);
+        return resultsBlock;
+      }
+    }
+
+    AggregationGroupByResult aggGroupByResult =
+        new AggregationGroupByResult(groupKeyGenerator, _aggregationFunctions, groupByResultHolders);
+    GroupByResultsBlock resultsBlock = new GroupByResultsBlock(_dataSchema, aggGroupByResult);
+    resultsBlock.setNumGroupsLimitReached(numGroupsLimitReached);
+    return resultsBlock;
+  }
+
+  /**
+   * Returns the transform operators of all (aggregation functions, transform operator) pairs, in list order.
+   */
+  @Override
+  public List<Operator> getChildOperators() {
+    return _aggFunctionsWithTransformOperator.stream().map(Pair::getRight).collect(Collectors.toList());
+  }
+
+  /**
+   * Returns the statistics accumulated by {@link #getNextBlock()} over all transform operators.
+   */
+  @Override
+  public ExecutionStatistics getExecutionStatistics() {
+    return new ExecutionStatistics(_numDocsScanned, _numEntriesScannedInFilter, _numEntriesScannedPostFilter,
+        _numTotalDocs);
+  }
+
+  /**
+   * Returns the explain string in the form
+   * <code>GROUP_BY_FILTERED(groupKeys:k1, k2, aggregations:a1, a2)</code>.
+   */
+  @Override
+  public String toExplainString() {
+    StringBuilder stringBuilder = new StringBuilder(EXPLAIN_NAME).append("(groupKeys:");
+    if (_groupByExpressions.length > 0) {
+      stringBuilder.append(_groupByExpressions[0].toString());
+      for (int i = 1; i < _groupByExpressions.length; i++) {
+        stringBuilder.append(", ").append(_groupByExpressions[i].toString());
+      }
+    }
+
+    stringBuilder.append(", aggregations:");
+    if (_aggregationFunctions.length > 0) {
+      stringBuilder.append(_aggregationFunctions[0].toExplainString());
+      for (int i = 1; i < _aggregationFunctions.length; i++) {
+        stringBuilder.append(", ").append(_aggregationFunctions[i].toExplainString());
+      }
+    }
+
+    return stringBuilder.append(')').toString();
+  }
+}
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java
b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java
index 58d74fb00f..148911897e 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java
@@ -18,19 +18,15 @@
*/
package org.apache.pinot.core.plan;
-import java.util.ArrayList;
import java.util.EnumSet;
-import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.pinot.common.request.context.ExpressionContext;
-import org.apache.pinot.common.request.context.FilterContext;
import org.apache.pinot.core.common.Operator;
import org.apache.pinot.core.operator.blocks.results.AggregationResultsBlock;
import org.apache.pinot.core.operator.filter.BaseFilterOperator;
-import org.apache.pinot.core.operator.filter.CombinedFilterOperator;
import org.apache.pinot.core.operator.query.AggregationOperator;
import org.apache.pinot.core.operator.query.FastFilteredCountOperator;
import org.apache.pinot.core.operator.query.FilteredAggregationOperator;
@@ -77,7 +73,7 @@ public class AggregationPlanNode implements PlanNode {
@Override
public Operator<AggregationResultsBlock> run() {
assert _queryContext.getAggregationFunctions() != null;
- return _queryContext.isHasFilteredAggregations() ?
buildFilteredAggOperator() : buildNonFilteredAggOperator();
+ return _queryContext.hasFilteredAggregations() ?
buildFilteredAggOperator() : buildNonFilteredAggOperator();
}
/**
@@ -86,83 +82,18 @@ public class AggregationPlanNode implements PlanNode {
private FilteredAggregationOperator buildFilteredAggOperator() {
int numTotalDocs = _indexSegment.getSegmentMetadata().getTotalDocs();
// Build the operator chain for the main predicate
- Pair<FilterPlanNode, BaseFilterOperator> filterOperatorPair =
buildFilterOperator(_queryContext.getFilter());
- TransformOperator transformOperator =
buildTransformOperatorForFilteredAggregates(filterOperatorPair.getRight());
-
- return buildFilterOperatorInternal(filterOperatorPair.getRight(),
transformOperator, numTotalDocs);
- }
-
- /**
- * Build a FilteredAggregationOperator given the parameters.
- * @param mainPredicateFilterOperator Filter operator corresponding to the
main predicate
- * @param mainTransformOperator Transform operator corresponding to the main
predicate
- * @param numTotalDocs Number of total docs
- */
- private FilteredAggregationOperator
buildFilterOperatorInternal(BaseFilterOperator mainPredicateFilterOperator,
- TransformOperator mainTransformOperator, int numTotalDocs) {
- Map<FilterContext, Pair<List<AggregationFunction>, TransformOperator>>
filterContextToAggFuncsMap = new HashMap<>();
- List<AggregationFunction> nonFilteredAggregationFunctions = new
ArrayList<>();
- List<Pair<AggregationFunction, FilterContext>> aggregationFunctions =
- _queryContext.getFilteredAggregationFunctions();
-
- // For each aggregation function, check if the aggregation function is a
filtered agg.
- // If it is, populate the corresponding filter operator and corresponding
transform operator
- for (Pair<AggregationFunction, FilterContext> inputPair :
aggregationFunctions) {
- if (inputPair.getLeft() != null) {
- FilterContext currentFilterExpression = inputPair.getRight();
- if (filterContextToAggFuncsMap.get(currentFilterExpression) != null) {
-
filterContextToAggFuncsMap.get(currentFilterExpression).getLeft().add(inputPair.getLeft());
- continue;
- }
- Pair<FilterPlanNode, BaseFilterOperator> pair =
buildFilterOperator(currentFilterExpression);
- BaseFilterOperator wrappedFilterOperator =
- new CombinedFilterOperator(mainPredicateFilterOperator,
pair.getRight(), _queryContext.getQueryOptions());
- TransformOperator newTransformOperator =
buildTransformOperatorForFilteredAggregates(wrappedFilterOperator);
- // For each transform operator, associate it with the underlying
expression. This allows
- // fetching the relevant TransformOperator when resolving blocks
during aggregation
- // execution
- List<AggregationFunction> aggFunctionList = new ArrayList<>();
- aggFunctionList.add(inputPair.getLeft());
- filterContextToAggFuncsMap.put(currentFilterExpression,
Pair.of(aggFunctionList, newTransformOperator));
- } else {
- nonFilteredAggregationFunctions.add(inputPair.getLeft());
- }
- }
- List<Pair<AggregationFunction[], TransformOperator>> aggToTransformOpList
= new ArrayList<>();
- // Convert to array since FilteredAggregationOperator expects it
- for (Pair<List<AggregationFunction>, TransformOperator> pair :
filterContextToAggFuncsMap.values()) {
- List<AggregationFunction> aggregationFunctionList = pair.getLeft();
- if (aggregationFunctionList == null) {
- throw new IllegalStateException("Null aggregation list seen");
- }
- aggToTransformOpList.add(Pair.of(aggregationFunctionList.toArray(new
AggregationFunction[0]), pair.getRight()));
- }
- aggToTransformOpList.add(
- Pair.of(nonFilteredAggregationFunctions.toArray(new
AggregationFunction[0]), mainTransformOperator));
+ Pair<FilterPlanNode, BaseFilterOperator> filterOperatorPair =
+ AggregationFunctionUtils.buildFilterOperator(_indexSegment,
_queryContext);
+ TransformOperator transformOperator =
+
AggregationFunctionUtils.buildTransformOperatorForFilteredAggregates(_indexSegment,
_queryContext,
+ filterOperatorPair.getRight(), null);
+ List<Pair<AggregationFunction[], TransformOperator>> aggToTransformOpList =
+ AggregationFunctionUtils.buildFilteredAggTransformPairs(_indexSegment,
_queryContext,
+ filterOperatorPair.getRight(), transformOperator, null);
return new
FilteredAggregationOperator(_queryContext.getAggregationFunctions(),
aggToTransformOpList, numTotalDocs);
}
- /**
- * Build a filter operator from the given FilterContext.
- *
- * It returns the FilterPlanNode to allow reusing plan level components such
as predicate
- * evaluator map
- */
- private Pair<FilterPlanNode, BaseFilterOperator>
buildFilterOperator(FilterContext filterContext) {
- FilterPlanNode filterPlanNode = new FilterPlanNode(_indexSegment,
_queryContext, filterContext);
- return Pair.of(filterPlanNode, filterPlanNode.run());
- }
-
- private TransformOperator
buildTransformOperatorForFilteredAggregates(BaseFilterOperator filterOperator) {
- AggregationFunction[] aggregationFunctions =
_queryContext.getAggregationFunctions();
- Set<ExpressionContext> expressionsToTransform =
-
AggregationFunctionUtils.collectExpressionsToTransform(aggregationFunctions,
null);
-
- return new TransformPlanNode(_indexSegment, _queryContext,
expressionsToTransform,
- DocIdSetPlanNode.MAX_DOC_PER_CALL, filterOperator).run();
- }
-
/**
* Processing workhorse for non filtered aggregates. Note that this code
path is invoked only
* if the query has no filtered aggregates at all. If a query has mixed
aggregates, filtered
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/plan/GroupByPlanNode.java
b/pinot-core/src/main/java/org/apache/pinot/core/plan/GroupByPlanNode.java
index 2b5da7896b..99fdec9746 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/plan/GroupByPlanNode.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/GroupByPlanNode.java
@@ -21,8 +21,12 @@ package org.apache.pinot.core.plan;
import java.util.List;
import java.util.Map;
import java.util.Set;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.core.common.Operator;
+import org.apache.pinot.core.operator.blocks.results.GroupByResultsBlock;
import org.apache.pinot.core.operator.filter.BaseFilterOperator;
+import org.apache.pinot.core.operator.query.FilteredGroupByOperator;
import org.apache.pinot.core.operator.query.GroupByOperator;
import org.apache.pinot.core.operator.transform.TransformOperator;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
@@ -50,10 +54,34 @@ public class GroupByPlanNode implements PlanNode {
}
@Override
- public GroupByOperator run() {
+ public Operator<GroupByResultsBlock> run() {
assert _queryContext.getAggregationFunctions() != null;
assert _queryContext.getGroupByExpressions() != null;
+ if (_queryContext.hasFilteredAggregations()) {
+ return buildFilteredGroupByPlan();
+ }
+ return buildNonFilteredGroupByPlan();
+ }
+
+  /**
+   * Builds the group-by operator for a query that has filtered aggregations.
+   *
+   * <p>The filter operator of the main predicate is built once and handed both to the main transform operator and to
+   * the builder of the per-filter (aggregation functions, transform operator) pairs.
+   */
+  private FilteredGroupByOperator buildFilteredGroupByPlan() {
+    int numTotalDocs = _indexSegment.getSegmentMetadata().getTotalDocs();
+    // Build the operator chain for the main predicate so the filter plan can be run only one time
+    Pair<FilterPlanNode, BaseFilterOperator> filterOperatorPair =
+        AggregationFunctionUtils.buildFilterOperator(_indexSegment, _queryContext);
+    BaseFilterOperator mainFilterOperator = filterOperatorPair.getRight();
+    // Materialize the group-by expressions once and reuse the array for the transform operators and the operator
+    ExpressionContext[] groupByExpressions = _queryContext.getGroupByExpressions().toArray(new ExpressionContext[0]);
+    TransformOperator transformOperator =
+        AggregationFunctionUtils.buildTransformOperatorForFilteredAggregates(_indexSegment, _queryContext,
+            mainFilterOperator, groupByExpressions);
+
+    List<Pair<AggregationFunction[], TransformOperator>> aggToTransformOpList =
+        AggregationFunctionUtils.buildFilteredAggTransformPairs(_indexSegment, _queryContext, mainFilterOperator,
+            transformOperator, groupByExpressions);
+    return new FilteredGroupByOperator(_queryContext.getAggregationFunctions(), aggToTransformOpList,
+        groupByExpressions, numTotalDocs, _queryContext);
+  }
+
+ private GroupByOperator buildNonFilteredGroupByPlan() {
int numTotalDocs = _indexSegment.getSegmentMetadata().getTotalDocs();
AggregationFunction[] aggregationFunctions =
_queryContext.getAggregationFunctions();
ExpressionContext[] groupByExpressions =
_queryContext.getGroupByExpressions().toArray(new ExpressionContext[0]);
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
index 8ef21fa1b4..0dcecb046d 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
@@ -18,6 +18,7 @@
*/
package org.apache.pinot.core.query.aggregation.function;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
@@ -26,13 +27,23 @@ import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.annotation.Nullable;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.pinot.common.datatable.DataTable;
import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.common.request.context.FilterContext;
import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.common.ObjectSerDeUtils;
import org.apache.pinot.core.operator.blocks.TransformBlock;
+import org.apache.pinot.core.operator.filter.BaseFilterOperator;
+import org.apache.pinot.core.operator.filter.CombinedFilterOperator;
+import org.apache.pinot.core.operator.transform.TransformOperator;
+import org.apache.pinot.core.plan.DocIdSetPlanNode;
+import org.apache.pinot.core.plan.FilterPlanNode;
+import org.apache.pinot.core.plan.TransformPlanNode;
+import org.apache.pinot.core.query.request.context.QueryContext;
import org.apache.pinot.segment.spi.AggregationFunctionType;
+import org.apache.pinot.segment.spi.IndexSegment;
import
org.apache.pinot.segment.spi.index.startree.AggregationFunctionColumnPair;
@@ -165,4 +176,87 @@ public class AggregationFunctionUtils {
throw new IllegalStateException("Illegal column data type in final
result: " + columnDataType);
}
}
+
+  /**
+   * Creates a {@link FilterPlanNode} for the given filter and runs it.
+   *
+   * @return Pair of the plan node and the filter operator it produced. The plan node is returned as well to allow
+   *         reusing plan level components such as predicate evaluator map
+   */
+  public static Pair<FilterPlanNode, BaseFilterOperator> buildFilterOperator(IndexSegment indexSegment,
+      QueryContext queryContext, FilterContext filterContext) {
+    FilterPlanNode planNode = new FilterPlanNode(indexSegment, queryContext, filterContext);
+    BaseFilterOperator filterOperator = planNode.run();
+    return Pair.of(planNode, filterOperator);
+  }
+
+  /**
+   * Builds the filter operator for the main filter of the query, i.e. {@link QueryContext#getFilter()}.
+   */
+  public static Pair<FilterPlanNode, BaseFilterOperator> buildFilterOperator(IndexSegment indexSegment,
+      QueryContext queryContext) {
+    FilterContext mainFilter = queryContext.getFilter();
+    return buildFilterOperator(indexSegment, queryContext, mainFilter);
+  }
+
+  /**
+   * Builds a transform operator on top of the given filter operator.
+   *
+   * <p>The expressions to transform are collected from ALL aggregation functions of the query plus the optional
+   * group-by expressions, regardless of which filter the operator is built for.
+   *
+   * @param filterOperator Filter operator feeding the transform operator
+   * @param groupByExpressions Group-by expressions, or {@code null} when there are none
+   */
+  public static TransformOperator buildTransformOperatorForFilteredAggregates(IndexSegment indexSegment,
+      QueryContext queryContext, BaseFilterOperator filterOperator,
+      @Nullable ExpressionContext[] groupByExpressions) {
+    AggregationFunction[] allAggregationFunctions = queryContext.getAggregationFunctions();
+    assert allAggregationFunctions != null;
+    Set<ExpressionContext> expressions = collectExpressionsToTransform(allAggregationFunctions, groupByExpressions);
+    TransformPlanNode transformPlanNode =
+        new TransformPlanNode(indexSegment, queryContext, expressions, DocIdSetPlanNode.MAX_DOC_PER_CALL,
+            filterOperator);
+    return transformPlanNode.run();
+  }
+
+  /**
+   * Build pairs of filtered aggregation functions and corresponding transform operator.
+   *
+   * <p>Aggregation functions sharing the same filter are grouped into one pair whose transform operator runs on the
+   * main predicate combined with that filter. Aggregation functions without a filter are grouped into the last pair
+   * of the returned list, which uses the main transform operator. That last pair is always present, even when its
+   * function array is empty.
+   *
+   * @param indexSegment Segment to build the operators for
+   * @param queryContext Query context
+   * @param mainPredicateFilterOperator Filter operator corresponding to the main predicate
+   * @param mainTransformOperator Transform operator corresponding to the main predicate
+   * @param groupByExpressions Group-by expressions, or {@code null} when there are none
+   * @return List of (aggregation functions, transform operator) pairs
+   */
+  public static List<Pair<AggregationFunction[], TransformOperator>> buildFilteredAggTransformPairs(
+      IndexSegment indexSegment, QueryContext queryContext, BaseFilterOperator mainPredicateFilterOperator,
+      TransformOperator mainTransformOperator, @Nullable ExpressionContext[] groupByExpressions) {
+    List<Pair<AggregationFunction, FilterContext>> aggregationFunctions =
+        queryContext.getFilteredAggregationFunctions();
+    assert aggregationFunctions != null;
+
+    Map<FilterContext, Pair<List<AggregationFunction>, TransformOperator>> filterContextToAggFuncsMap = new HashMap<>();
+    List<AggregationFunction> nonFilteredAggregationFunctions = new ArrayList<>();
+
+    // For each aggregation function, check if the aggregation function is a filtered agg.
+    // If it is, populate the corresponding filter operator and corresponding transform operator
+    for (Pair<AggregationFunction, FilterContext> inputPair : aggregationFunctions) {
+      AggregationFunction aggregationFunction = inputPair.getLeft();
+      FilterContext currentFilterExpression = inputPair.getRight();
+
+      // The FILTER (right side) decides whether the aggregation is filtered; the function itself (left side) is the
+      // value being grouped and says nothing about filtering. Non-filtered aggregations are expected to carry a
+      // null filter and are served by the main transform operator.
+      if (currentFilterExpression == null) {
+        nonFilteredAggregationFunctions.add(aggregationFunction);
+        continue;
+      }
+
+      Pair<List<AggregationFunction>, TransformOperator> existingPair =
+          filterContextToAggFuncsMap.get(currentFilterExpression);
+      if (existingPair != null) {
+        // Same filter seen before: share its transform operator
+        existingPair.getLeft().add(aggregationFunction);
+        continue;
+      }
+
+      Pair<FilterPlanNode, BaseFilterOperator> filterPlanOpPair =
+          buildFilterOperator(indexSegment, queryContext, currentFilterExpression);
+      BaseFilterOperator wrappedFilterOperator =
+          new CombinedFilterOperator(mainPredicateFilterOperator, filterPlanOpPair.getRight(),
+              queryContext.getQueryOptions());
+      TransformOperator newTransformOperator =
+          buildTransformOperatorForFilteredAggregates(indexSegment, queryContext, wrappedFilterOperator,
+              groupByExpressions);
+      // For each transform operator, associate it with the underlying expression. This allows
+      // fetching the relevant TransformOperator when resolving blocks during aggregation
+      // execution
+      List<AggregationFunction> aggFunctionList = new ArrayList<>();
+      aggFunctionList.add(aggregationFunction);
+      filterContextToAggFuncsMap.put(currentFilterExpression, Pair.of(aggFunctionList, newTransformOperator));
+    }
+
+    // Convert to array since FilteredGroupByOperator expects it
+    List<Pair<AggregationFunction[], TransformOperator>> aggToTransformOpList = new ArrayList<>();
+    for (Pair<List<AggregationFunction>, TransformOperator> pair : filterContextToAggFuncsMap.values()) {
+      aggToTransformOpList.add(Pair.of(pair.getLeft().toArray(new AggregationFunction[0]), pair.getRight()));
+    }
+    aggToTransformOpList.add(
+        Pair.of(nonFilteredAggregationFunctions.toArray(new AggregationFunction[0]), mainTransformOperator));
+
+    return aggToTransformOpList;
+  }
}
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/DefaultGroupByExecutor.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/DefaultGroupByExecutor.java
index e0af94070c..38ebd3706c 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/DefaultGroupByExecutor.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/DefaultGroupByExecutor.java
@@ -20,6 +20,7 @@ package org.apache.pinot.core.query.aggregation.groupby;
import java.util.Collection;
import java.util.Map;
+import javax.annotation.Nullable;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.data.table.IntermediateRecord;
@@ -58,16 +59,28 @@ public class DefaultGroupByExecutor implements
GroupByExecutor {
protected final int[] _svGroupKeys;
protected final int[][] _mvGroupKeys;
+ /**
+ * Constructor for the class, covering all aggregation functions of the query
+ * (<code>queryContext.getAggregationFunctions()</code>). No pre-existing group key generator is passed, so the
+ * executor creates its own.
+ *
+ * @param queryContext Query context
+ * @param groupByExpressions Array of group-by expressions
+ * @param transformOperator Transform operator
+ */
+ public DefaultGroupByExecutor(QueryContext queryContext, ExpressionContext[]
groupByExpressions,
+ TransformOperator transformOperator) {
+ this(queryContext, queryContext.getAggregationFunctions(),
groupByExpressions, transformOperator, null);
+ }
+
+ /**
+ * Constructor for the class, restricted to the given aggregation functions. No pre-existing group key generator is
+ * passed, so the executor creates its own.
+ *
+ * @param queryContext Query context
+ * @param aggregationFunctions Aggregation functions
+ * @param groupByExpressions Array of group-by expressions
+ * @param transformOperator Transform operator
+ */
+ public DefaultGroupByExecutor(QueryContext queryContext,
AggregationFunction[] aggregationFunctions,
+ ExpressionContext[] groupByExpressions, TransformOperator
transformOperator) {
+ this(queryContext, aggregationFunctions, groupByExpressions,
transformOperator, null);
+ }
+
/**
* Constructor for the class.
*
* @param queryContext Query context
+ * @param aggregationFunctions Aggregation functions
* @param groupByExpressions Array of group-by expressions
* @param transformOperator Transform operator
*/
- public DefaultGroupByExecutor(QueryContext queryContext, ExpressionContext[]
groupByExpressions,
- TransformOperator transformOperator) {
- _aggregationFunctions = queryContext.getAggregationFunctions();
+ public DefaultGroupByExecutor(QueryContext queryContext,
AggregationFunction[] aggregationFunctions,
+ ExpressionContext[] groupByExpressions, TransformOperator
transformOperator,
+ @Nullable GroupKeyGenerator groupKeyGenerator) {
+ _aggregationFunctions = aggregationFunctions;
assert _aggregationFunctions != null;
_nullHandlingEnabled = queryContext.isNullHandlingEnabled();
@@ -83,19 +96,23 @@ public class DefaultGroupByExecutor implements
GroupByExecutor {
// Initialize group key generator
int numGroupsLimit = queryContext.getNumGroupsLimit();
int maxInitialResultHolderCapacity =
queryContext.getMaxInitialResultHolderCapacity();
- if (hasNoDictionaryGroupByExpression || _nullHandlingEnabled) {
- if (groupByExpressions.length == 1) {
- // TODO(nhejazi): support MV and dictionary based when null handling
is enabled.
- _groupKeyGenerator =
- new NoDictionarySingleColumnGroupKeyGenerator(transformOperator,
groupByExpressions[0], numGroupsLimit,
- _nullHandlingEnabled);
+ if (groupKeyGenerator != null) {
+ _groupKeyGenerator = groupKeyGenerator;
+ } else {
+ if (hasNoDictionaryGroupByExpression || _nullHandlingEnabled) {
+ if (groupByExpressions.length == 1) {
+ // TODO(nhejazi): support MV and dictionary based when null handling
is enabled.
+ _groupKeyGenerator =
+ new NoDictionarySingleColumnGroupKeyGenerator(transformOperator,
groupByExpressions[0], numGroupsLimit,
+ _nullHandlingEnabled);
+ } else {
+ _groupKeyGenerator =
+ new NoDictionaryMultiColumnGroupKeyGenerator(transformOperator,
groupByExpressions, numGroupsLimit);
+ }
} else {
- _groupKeyGenerator =
- new NoDictionaryMultiColumnGroupKeyGenerator(transformOperator,
groupByExpressions, numGroupsLimit);
+ _groupKeyGenerator = new
DictionaryBasedGroupKeyGenerator(transformOperator, groupByExpressions,
numGroupsLimit,
+ maxInitialResultHolderCapacity);
}
- } else {
- _groupKeyGenerator = new
DictionaryBasedGroupKeyGenerator(transformOperator, groupByExpressions,
numGroupsLimit,
- maxInitialResultHolderCapacity);
}
// Initialize result holders
@@ -141,7 +158,6 @@ public class DefaultGroupByExecutor implements
GroupByExecutor {
AggregationFunction aggregationFunction =
_aggregationFunctions[functionIndex];
Map<ExpressionContext, BlockValSet> blockValSetMap =
AggregationFunctionUtils.getBlockValSetMap(aggregationFunction,
transformBlock);
-
GroupByResultHolder groupByResultHolder =
_groupByResultHolders[functionIndex];
if (_hasMVGroupByExpression) {
aggregationFunction.aggregateGroupByMV(length, _mvGroupKeys,
groupByResultHolder, blockValSetMap);
@@ -164,4 +180,14 @@ public class DefaultGroupByExecutor implements
GroupByExecutor {
public Collection<IntermediateRecord> trimGroupByResult(int trimSize,
TableResizer tableResizer) {
return tableResizer.trimInSegmentResults(_groupKeyGenerator,
_groupByResultHolders, trimSize);
}
+
+ @Override
+ public GroupKeyGenerator getGroupKeyGenerator() {
+ return _groupKeyGenerator;
+ }
+
+ @Override
+ public GroupByResultHolder[] getGroupByResultHolders() {
+ return _groupByResultHolders;
+ }
}
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/GroupByExecutor.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/GroupByExecutor.java
index 869ef5dbe9..db5ff16b18 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/GroupByExecutor.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/GroupByExecutor.java
@@ -58,4 +58,8 @@ public interface GroupByExecutor {
*
*/
Collection<IntermediateRecord> trimGroupByResult(int trimSize, TableResizer
tableResizer);
+
+ GroupKeyGenerator getGroupKeyGenerator();
+
+ GroupByResultHolder[] getGroupByResultHolders();
}
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java
index 5c1bd2fe84..fcc97dd6fd 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java
@@ -260,7 +260,7 @@ public class QueryContext {
/**
* Returns the filtered aggregation expressions for the query.
*/
- public boolean isHasFilteredAggregations() {
+ public boolean hasFilteredAggregations() {
return _hasFilteredAggregations;
}
@@ -536,10 +536,6 @@ public class QueryContext {
FunctionContext aggregation = pair.getLeft();
FilterContext filter = pair.getRight();
if (filter != null) {
- // Filtered aggregation
- if (_groupByExpressions != null) {
- throw new IllegalStateException("GROUP BY with FILTER clauses is
not supported");
- }
queryContext._hasFilteredAggregations = true;
}
int functionIndex = filteredAggregationFunctions.size();
diff --git
a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/groupby/GroupByTrimTest.java
b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/groupby/GroupByTrimTest.java
index 62236f3a4b..dba3faefe6 100644
---
a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/groupby/GroupByTrimTest.java
+++
b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/groupby/GroupByTrimTest.java
@@ -29,11 +29,11 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.tuple.Pair;
+import org.apache.pinot.core.common.Operator;
import org.apache.pinot.core.data.table.Record;
import org.apache.pinot.core.data.table.Table;
import org.apache.pinot.core.operator.blocks.results.GroupByResultsBlock;
import org.apache.pinot.core.operator.combine.GroupByCombineOperator;
-import org.apache.pinot.core.operator.query.GroupByOperator;
import org.apache.pinot.core.plan.GroupByPlanNode;
import org.apache.pinot.core.query.request.context.QueryContext;
import
org.apache.pinot.core.query.request.context.utils.QueryContextConverterUtils;
@@ -50,13 +50,12 @@ import org.apache.pinot.spi.data.readers.GenericRow;
import org.apache.pinot.spi.utils.CommonConstants;
import org.apache.pinot.spi.utils.ReadMode;
import org.apache.pinot.spi.utils.builder.TableConfigBuilder;
+import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
-import static org.testng.Assert.assertEquals;
-
/**
* Unit test for GroupBy Trim functionalities.
@@ -120,7 +119,7 @@ public class GroupByTrimTest {
queryContext.setMinServerGroupTrimSize(minServerGroupTrimSize);
// Create a query operator
- GroupByOperator groupByOperator = new GroupByPlanNode(_indexSegment,
queryContext).run();
+ Operator<GroupByResultsBlock> groupByOperator = new
GroupByPlanNode(_indexSegment, queryContext).run();
GroupByCombineOperator combineOperator =
new GroupByCombineOperator(Collections.singletonList(groupByOperator),
queryContext, _executorService);
@@ -130,7 +129,7 @@ public class GroupByTrimTest {
// Extract the execution result
List<Pair<Double, Double>> extractedResult =
extractTestResult(resultsBlock.getTable());
- assertEquals(extractedResult, expectedResult);
+ Assert.assertEquals(extractedResult, expectedResult);
}
/**
diff --git
a/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java
b/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java
index 2fc9ad1fa6..9d772abc3f 100644
---
a/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java
+++
b/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java
@@ -202,10 +202,10 @@ public class FilteredAggregationsTest extends
BaseQueriesTest {
nonFilterQuery = "SELECT SUM(INT_COL) FROM MyTable WHERE BOOLEAN_COL=true
AND STARTSWITH(STRING_COL, 'abc')";
testQuery(filterQuery, nonFilterQuery);
- filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE BOOLEAN_COL AND
STARTSWITH(REVERSE(STRING_COL), 'abc')) FROM "
- + "MyTable";
- nonFilterQuery = "SELECT SUM(INT_COL) FROM MyTable WHERE BOOLEAN_COL=true
AND STARTSWITH(REVERSE(STRING_COL), "
- + "'abc')";
+ filterQuery =
+ "SELECT SUM(INT_COL) FILTER(WHERE BOOLEAN_COL AND
STARTSWITH(REVERSE(STRING_COL), 'abc')) FROM " + "MyTable";
+ nonFilterQuery =
+ "SELECT SUM(INT_COL) FROM MyTable WHERE BOOLEAN_COL=true AND
STARTSWITH(REVERSE(STRING_COL), " + "'abc')";
testQuery(filterQuery, nonFilterQuery);
}
@@ -335,10 +335,49 @@ public class FilteredAggregationsTest extends
BaseQueriesTest {
testQuery(filterQuery, nonFilterQuery);
}
- @Test(expectedExceptions = IllegalStateException.class)
- public void testGroupBySupport() {
- String filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 2),
MAX(INT_COL) FILTER(WHERE INT_COL > 2) "
- + "FROM MyTable WHERE INT_COL < 1000 GROUP BY INT_COL";
- getBrokerResponse(filterQuery);
+ @Test
+ public void testGroupBy() {
+ String filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 25000)
FROM MyTable GROUP BY BOOLEAN_COL";
+ String nonFilterQuery = "SELECT SUM(INT_COL) FROM MyTable WHERE INT_COL >
25000 GROUP BY BOOLEAN_COL";
+ testQuery(filterQuery, nonFilterQuery);
+ }
+
+ @Test
+ public void testGroupByCaseAlternative() {
+ String filterQuery =
+ "SELECT SUM(INT_COL), SUM(INT_COL) FILTER(WHERE INT_COL > 25000) AS
total_sum FROM MyTable GROUP BY "
+ + "BOOLEAN_COL";
+ String nonFilterQuery =
+ "SELECT SUM(INT_COL), SUM(CASE WHEN INT_COL > 25000 THEN INT_COL ELSE
0 END) AS total_sum FROM MyTable GROUP "
+ + "BY BOOLEAN_COL";
+ testQuery(filterQuery, nonFilterQuery);
+ }
+
+ @Test
+ public void testGroupBySameFilter() {
+ String filterQuery =
+ "SELECT AVG(INT_COL) FILTER(WHERE INT_COL > 25000), SUM(INT_COL)
FILTER(WHERE INT_COL > 25000) FROM MyTable "
+ + "GROUP BY BOOLEAN_COL";
+ String nonFilterQuery = "SELECT AVG(INT_COL), SUM(INT_COL) FROM MyTable
WHERE INT_COL > 25000 GROUP BY BOOLEAN_COL";
+ testQuery(filterQuery, nonFilterQuery);
+ }
+
+ @Test
+ public void testMultipleAggregationsOnSameFilterGroupBy() {
+ String filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL >
29990), "
+ + "MAX(INT_COL) FILTER(WHERE INT_COL > 29990) FROM MyTable GROUP BY
BOOLEAN_COL";
+ String nonFilterQuery = "SELECT MIN(INT_COL), MAX(INT_COL) FROM MyTable
WHERE INT_COL > 29990 GROUP BY BOOLEAN_COL";
+ testQuery(filterQuery, nonFilterQuery);
+
+ filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 29990) AS
total_min, "
+ + "MAX(INT_COL) FILTER(WHERE INT_COL > 29990) AS total_max, "
+ + "SUM(INT_COL) FILTER(WHERE NO_INDEX_COL < 5000) AS total_sum, "
+ + "MAX(NO_INDEX_COL) FILTER(WHERE NO_INDEX_COL < 5000) AS total_max2
FROM MyTable GROUP BY BOOLEAN_COL";
+ nonFilterQuery = "SELECT MIN(CASE WHEN (NO_INDEX_COL > 29990) THEN INT_COL
ELSE 99999 END) AS total_min, "
+ + "MAX(CASE WHEN (INT_COL > 29990) THEN INT_COL ELSE 0 END) AS
total_max, "
+ + "SUM(CASE WHEN (NO_INDEX_COL < 5000) THEN INT_COL ELSE 0 END) AS
total_sum, "
+ + "MAX(CASE WHEN (NO_INDEX_COL < 5000) THEN NO_INDEX_COL ELSE 0 END)
AS total_max2 FROM MyTable GROUP BY "
+ + "BOOLEAN_COL";
+ testQuery(filterQuery, nonFilterQuery);
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]