This is an automated email from the ASF dual-hosted git repository.
jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 36307cb501 Fixes filtered agg result column naming and filtered agg order-by compat (#10092)
36307cb501 is described below
commit 36307cb5019836c755a2422fa63ef380839eba58
Author: Evan Galpin <[email protected]>
AuthorDate: Fri Jan 13 12:07:15 2023 -0800
Fixes filtered agg result column naming and filtered agg order-by compat (#10092)
---
.../apache/pinot/core/data/table/TableResizer.java | 21 ++++
.../operator/blocks/results/ResultsBlockUtils.java | 20 ++-
.../operator/query/FilteredGroupByOperator.java | 13 +-
.../apache/pinot/core/plan/GroupByPlanNode.java | 3 +-
.../function/AggregationFunctionUtils.java | 8 ++
.../query/reduce/AggregationDataTableReducer.java | 17 ++-
.../pinot/queries/FilteredAggregationsTest.java | 134 +++++++++++++++------
...terSegmentAggregationMultiValueQueriesTest.java | 3 +-
...SegmentAggregationMultiValueRawQueriesTest.java | 4 +-
9 files changed, 173 insertions(+), 50 deletions(-)
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java
b/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java
index 7f6704fd7a..cbbe6abdce 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java
@@ -28,9 +28,12 @@ import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.common.request.context.FilterContext;
import org.apache.pinot.common.request.context.FunctionContext;
import org.apache.pinot.common.request.context.OrderByExpressionContext;
+import org.apache.pinot.common.request.context.RequestContextUtils;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
@@ -51,6 +54,8 @@ public class TableResizer {
private final Map<ExpressionContext, Integer> _groupByExpressionIndexMap;
private final AggregationFunction[] _aggregationFunctions;
private final Map<FunctionContext, Integer> _aggregationFunctionIndexMap;
+ private final Map<Pair<FunctionContext, FilterContext>, Integer>
_filteredAggregationIndexMap;
+ private final List<Pair<AggregationFunction, FilterContext>>
_filteredAggregationFunctions;
private final int _numOrderByExpressions;
private final OrderByValueExtractor[] _orderByValueExtractors;
private final Comparator<IntermediateRecord> _intermediateRecordComparator;
@@ -73,6 +78,8 @@ public class TableResizer {
assert _aggregationFunctions != null;
_aggregationFunctionIndexMap =
queryContext.getAggregationFunctionIndexMap();
assert _aggregationFunctionIndexMap != null;
+ _filteredAggregationIndexMap =
queryContext.getFilteredAggregationsIndexMap();
+ _filteredAggregationFunctions =
queryContext.getFilteredAggregationFunctions();
List<OrderByExpressionContext> orderByExpressions =
queryContext.getOrderByExpressions();
assert orderByExpressions != null;
@@ -137,6 +144,15 @@ public class TableResizer {
if (function.getType() == FunctionContext.Type.AGGREGATION) {
// Aggregation function
return new
AggregationFunctionExtractor(_aggregationFunctionIndexMap.get(function));
+ } else if (function.getType() == FunctionContext.Type.TRANSFORM
+ && "FILTER".equalsIgnoreCase(function.getFunctionName())) {
+ FunctionContext aggregation =
function.getArguments().get(0).getFunction();
+ ExpressionContext filterExpression = function.getArguments().get(1);
+ FilterContext filter = RequestContextUtils.getFilter(filterExpression);
+
+ int functionIndex =
_filteredAggregationIndexMap.get(Pair.of(aggregation, filter));
+ AggregationFunction aggregationFunction =
_filteredAggregationFunctions.get(functionIndex).getLeft();
+ return new AggregationFunctionExtractor(functionIndex,
aggregationFunction);
} else {
// Post-aggregation function
return new PostAggregationFunctionExtractor(function);
@@ -414,6 +430,11 @@ public class TableResizer {
_aggregationFunction = _aggregationFunctions[aggregationFunctionIndex];
}
+ AggregationFunctionExtractor(int aggregationFunctionIndex,
AggregationFunction aggregationFunction) {
+ _index = aggregationFunctionIndex + _numGroupByExpressions;
+ _aggregationFunction = aggregationFunction;
+ }
+
@Override
public ColumnDataType getValueType() {
return _aggregationFunction.getFinalResultColumnType();
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/ResultsBlockUtils.java
b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/ResultsBlockUtils.java
index 5f5e7d0769..6fe8346a8b 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/ResultsBlockUtils.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/ResultsBlockUtils.java
@@ -22,10 +22,13 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.common.request.context.FilterContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
+import
org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
import
org.apache.pinot.core.query.aggregation.function.DistinctAggregationFunction;
import org.apache.pinot.core.query.distinct.DistinctTable;
import org.apache.pinot.core.query.request.context.QueryContext;
@@ -68,6 +71,8 @@ public class ResultsBlockUtils {
private static AggregationResultsBlock
buildEmptyAggregationQueryResults(QueryContext queryContext) {
AggregationFunction[] aggregationFunctions =
queryContext.getAggregationFunctions();
+ List<Pair<AggregationFunction, FilterContext>>
filteredAggregationFunctions =
+ queryContext.getFilteredAggregationFunctions();
assert aggregationFunctions != null;
int numAggregations = aggregationFunctions.length;
List<Object> results = new ArrayList<>(numAggregations);
@@ -78,12 +83,12 @@ public class ResultsBlockUtils {
}
private static GroupByResultsBlock
buildEmptyGroupByQueryResults(QueryContext queryContext) {
- AggregationFunction[] aggregationFunctions =
queryContext.getAggregationFunctions();
- assert aggregationFunctions != null;
- int numAggregations = aggregationFunctions.length;
+ List<Pair<AggregationFunction, FilterContext>>
filteredAggregationFunctions =
+ queryContext.getFilteredAggregationFunctions();
+
List<ExpressionContext> groupByExpressions =
queryContext.getGroupByExpressions();
assert groupByExpressions != null;
- int numColumns = groupByExpressions.size() + numAggregations;
+ int numColumns = groupByExpressions.size() +
filteredAggregationFunctions.size();
String[] columnNames = new String[numColumns];
ColumnDataType[] columnDataTypes = new ColumnDataType[numColumns];
int index = 0;
@@ -93,9 +98,12 @@ public class ResultsBlockUtils {
columnDataTypes[index] = ColumnDataType.STRING;
index++;
}
- for (AggregationFunction aggregationFunction : aggregationFunctions) {
+ for (Pair<AggregationFunction, FilterContext> aggFilterPair :
filteredAggregationFunctions) {
// NOTE: Use AggregationFunction.getResultColumnName() for SQL format
response
- columnNames[index] = aggregationFunction.getResultColumnName();
+ AggregationFunction aggregationFunction = aggFilterPair.getLeft();
+ String columnName =
+ AggregationFunctionUtils.getResultColumnName(aggregationFunction,
aggFilterPair.getRight());
+ columnNames[index] = columnName;
columnDataTypes[index] =
aggregationFunction.getIntermediateResultColumnType();
index++;
}
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java
b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java
index e895d817dd..872a999f54 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java
@@ -24,6 +24,7 @@ import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.common.request.context.FilterContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.common.Operator;
import org.apache.pinot.core.data.table.IntermediateRecord;
@@ -34,6 +35,7 @@ import org.apache.pinot.core.operator.blocks.TransformBlock;
import org.apache.pinot.core.operator.blocks.results.GroupByResultsBlock;
import org.apache.pinot.core.operator.transform.TransformOperator;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
+import
org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
import
org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult;
import org.apache.pinot.core.query.aggregation.groupby.DefaultGroupByExecutor;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
@@ -62,6 +64,7 @@ public class FilteredGroupByOperator extends
BaseOperator<GroupByResultsBlock> {
private final QueryContext _queryContext;
public FilteredGroupByOperator(AggregationFunction[] aggregationFunctions,
+ List<Pair<AggregationFunction, FilterContext>>
filteredAggregationFunctions,
List<Pair<AggregationFunction[], TransformOperator>>
aggFunctionsWithTransformOperator,
ExpressionContext[] groupByExpressions, long numTotalDocs, QueryContext
queryContext) {
_aggregationFunctions = aggregationFunctions;
@@ -87,9 +90,12 @@ public class FilteredGroupByOperator extends
BaseOperator<GroupByResultsBlock> {
// Extract column names and data types for aggregation functions
for (int i = 0; i < numAggregationFunctions; i++) {
- AggregationFunction aggregationFunction = aggregationFunctions[i];
int index = numGroupByExpressions + i;
- columnNames[index] = aggregationFunction.getResultColumnName();
+ Pair<AggregationFunction, FilterContext> filteredAggPair =
filteredAggregationFunctions.get(i);
+ AggregationFunction aggregationFunction = filteredAggPair.getLeft();
+ String columnName =
+ AggregationFunctionUtils.getResultColumnName(aggregationFunction,
filteredAggPair.getRight());
+ columnNames[index] = columnName;
columnDataTypes[index] =
aggregationFunction.getIntermediateResultColumnType();
}
@@ -102,7 +108,8 @@ public class FilteredGroupByOperator extends
BaseOperator<GroupByResultsBlock> {
int numAggregations = _aggregationFunctions.length;
GroupByResultHolder[] groupByResultHolders = new
GroupByResultHolder[numAggregations];
- IdentityHashMap<AggregationFunction, Integer> resultHolderIndexMap = new
IdentityHashMap<>(numAggregations);
+ IdentityHashMap<AggregationFunction, Integer> resultHolderIndexMap =
+ new IdentityHashMap<>(_aggregationFunctions.length);
for (int i = 0; i < numAggregations; i++) {
resultHolderIndexMap.put(_aggregationFunctions[i], i);
}
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/plan/GroupByPlanNode.java
b/pinot-core/src/main/java/org/apache/pinot/core/plan/GroupByPlanNode.java
index 99fdec9746..ccb51143e6 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/plan/GroupByPlanNode.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/GroupByPlanNode.java
@@ -77,7 +77,8 @@ public class GroupByPlanNode implements PlanNode {
List<Pair<AggregationFunction[], TransformOperator>> aggToTransformOpList =
AggregationFunctionUtils.buildFilteredAggTransformPairs(_indexSegment,
_queryContext,
filterOperatorPair.getRight(), transformOperator,
groupByExpressions);
- return new
FilteredGroupByOperator(_queryContext.getAggregationFunctions(),
aggToTransformOpList,
+ return new FilteredGroupByOperator(_queryContext.getAggregationFunctions(),
+ _queryContext.getFilteredAggregationFunctions(), aggToTransformOpList,
_queryContext.getGroupByExpressions().toArray(new
ExpressionContext[0]), numTotalDocs, _queryContext);
}
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
index 0dcecb046d..6b1dd21e3c 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
@@ -259,4 +259,12 @@ public class AggregationFunctionUtils {
return aggToTransformOpList;
}
+
+ public static String getResultColumnName(AggregationFunction
aggregationFunction, @Nullable FilterContext filter) {
+ String columnName = aggregationFunction.getResultColumnName();
+ if (filter != null) {
+ columnName += " FILTER(WHERE " + filter + ")";
+ }
+ return columnName;
+ }
}
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java
index b727df9c30..739c1f691e 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java
@@ -21,9 +21,12 @@ package org.apache.pinot.core.query.reduce;
import com.google.common.base.Preconditions;
import java.util.Collection;
import java.util.Collections;
+import java.util.List;
import java.util.Map;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.pinot.common.datatable.DataTable;
import org.apache.pinot.common.metrics.BrokerMetrics;
+import org.apache.pinot.common.request.context.FilterContext;
import org.apache.pinot.common.response.broker.BrokerResponseNative;
import org.apache.pinot.common.response.broker.ResultTable;
import org.apache.pinot.common.utils.DataSchema;
@@ -42,10 +45,12 @@ import org.roaringbitmap.RoaringBitmap;
public class AggregationDataTableReducer implements DataTableReducer {
private final QueryContext _queryContext;
private final AggregationFunction[] _aggregationFunctions;
+ private final List<Pair<AggregationFunction, FilterContext>>
_filteredAggregationFunctions;
AggregationDataTableReducer(QueryContext queryContext) {
_queryContext = queryContext;
_aggregationFunctions = queryContext.getAggregationFunctions();
+ _filteredAggregationFunctions =
queryContext.getFilteredAggregationFunctions();
}
/**
@@ -150,11 +155,17 @@ public class AggregationDataTableReducer implements
DataTableReducer {
int numAggregationFunctions = _aggregationFunctions.length;
String[] columnNames = new String[numAggregationFunctions];
ColumnDataType[] columnDataTypes = new
ColumnDataType[numAggregationFunctions];
- for (int i = 0; i < numAggregationFunctions; i++) {
- AggregationFunction aggregationFunction = _aggregationFunctions[i];
- columnNames[i] = aggregationFunction.getResultColumnName();
+
+ int i = 0;
+ for (Pair<AggregationFunction, FilterContext> aggFilterPair :
_filteredAggregationFunctions) {
+ AggregationFunction aggregationFunction = aggFilterPair.getLeft();
+ String columnName =
+ AggregationFunctionUtils.getResultColumnName(aggregationFunction,
aggFilterPair.getRight());
+ columnNames[i] = columnName;
columnDataTypes[i] = aggregationFunction.getFinalResultColumnType();
+ i++;
}
+
return new DataSchema(columnNames, columnDataTypes);
}
}
diff --git
a/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java
b/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java
index 9d772abc3f..2ea664ec67 100644
---
a/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java
+++
b/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java
@@ -161,51 +161,84 @@ public class FilteredAggregationsTest extends
BaseQueriesTest {
@Test
public void testSimpleQueries() {
- String filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 9999)
FROM MyTable WHERE INT_COL < 1000000";
- String nonFilterQuery = "SELECT SUM(INT_COL) FROM MyTable WHERE INT_COL >
9999 AND INT_COL < 1000000";
+ String filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 9999)
sum1 FROM MyTable WHERE INT_COL < 1000000";
+ String nonFilterQuery = "SELECT SUM(INT_COL) sum1 FROM MyTable WHERE
INT_COL > 9999 AND INT_COL < 1000000";
testQuery(filterQuery, nonFilterQuery);
- filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL < 3) FROM MyTable
WHERE INT_COL > 1";
- nonFilterQuery = "SELECT SUM(INT_COL) FROM MyTable WHERE INT_COL > 1 AND
INT_COL < 3";
+ filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL < 3) sum1 FROM
MyTable WHERE INT_COL > 1";
+ nonFilterQuery = "SELECT SUM(INT_COL) sum1 FROM MyTable WHERE INT_COL > 1
AND INT_COL < 3";
testQuery(filterQuery, nonFilterQuery);
- filterQuery = "SELECT COUNT(*) FILTER(WHERE INT_COL = 4) FROM MyTable";
- nonFilterQuery = "SELECT COUNT(*) FROM MyTable WHERE INT_COL = 4";
+ filterQuery = "SELECT COUNT(*) FILTER(WHERE INT_COL = 4) count1 FROM
MyTable";
+ nonFilterQuery = "SELECT COUNT(*) count1 FROM MyTable WHERE INT_COL = 4";
testQuery(filterQuery, nonFilterQuery);
- filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 8000) FROM
MyTable ";
- nonFilterQuery = "SELECT SUM(INT_COL) FROM MyTable WHERE INT_COL > 8000";
+ filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 8000) sum1 FROM
MyTable ";
+ nonFilterQuery = "SELECT SUM(INT_COL) sum1 FROM MyTable WHERE INT_COL >
8000";
testQuery(filterQuery, nonFilterQuery);
- filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE NO_INDEX_COL <= 1) FROM
MyTable WHERE INT_COL > 1";
- nonFilterQuery = "SELECT SUM(INT_COL) FROM MyTable WHERE NO_INDEX_COL <= 1
AND INT_COL > 1";
+ filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE NO_INDEX_COL <= 1) sum1
FROM MyTable WHERE INT_COL > 1";
+ nonFilterQuery = "SELECT SUM(INT_COL) sum1 FROM MyTable WHERE NO_INDEX_COL
<= 1 AND INT_COL > 1";
testQuery(filterQuery, nonFilterQuery);
- filterQuery = "SELECT AVG(NO_INDEX_COL) FROM MyTable WHERE NO_INDEX_COL >
-1";
- nonFilterQuery = "SELECT AVG(NO_INDEX_COL) FROM MyTable";
+ filterQuery = "SELECT AVG(NO_INDEX_COL) avg1 FROM MyTable WHERE
NO_INDEX_COL > -1";
+ nonFilterQuery = "SELECT AVG(NO_INDEX_COL) avg1 FROM MyTable";
testQuery(filterQuery, nonFilterQuery);
- filterQuery = "SELECT AVG(INT_COL) FILTER(WHERE NO_INDEX_COL > -1) FROM
MyTable";
- nonFilterQuery = "SELECT AVG(INT_COL) FROM MyTable";
+ filterQuery = "SELECT AVG(INT_COL) FILTER(WHERE NO_INDEX_COL > -1) avg1
FROM MyTable";
+ nonFilterQuery = "SELECT AVG(INT_COL) avg1 FROM MyTable";
testQuery(filterQuery, nonFilterQuery);
- filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 29990),
MAX(INT_COL) FILTER(WHERE INT_COL > 29990) "
- + "FROM MyTable";
- nonFilterQuery = "SELECT MIN(INT_COL), MAX(INT_COL) FROM MyTable WHERE
INT_COL > 29990";
+ filterQuery =
+ "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 29990) min1,
MAX(INT_COL) FILTER(WHERE INT_COL > 29990) max1"
+ + " FROM MyTable";
+ nonFilterQuery = "SELECT MIN(INT_COL) min1, MAX(INT_COL) max1 FROM MyTable
WHERE INT_COL > 29990";
testQuery(filterQuery, nonFilterQuery);
- filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE BOOLEAN_COL) FROM MyTable";
- nonFilterQuery = "SELECT SUM(INT_COL) FROM MyTable WHERE BOOLEAN_COL=true";
+ filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE BOOLEAN_COL) sum1 FROM
MyTable";
+ nonFilterQuery = "SELECT SUM(INT_COL) sum1 FROM MyTable WHERE
BOOLEAN_COL=true";
testQuery(filterQuery, nonFilterQuery);
- filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE BOOLEAN_COL AND
STARTSWITH(STRING_COL, 'abc')) FROM MyTable";
- nonFilterQuery = "SELECT SUM(INT_COL) FROM MyTable WHERE BOOLEAN_COL=true
AND STARTSWITH(STRING_COL, 'abc')";
+ filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE BOOLEAN_COL AND
STARTSWITH(STRING_COL, 'abc')) sum1 FROM MyTable";
+ nonFilterQuery = "SELECT SUM(INT_COL) sum1 FROM MyTable WHERE
BOOLEAN_COL=true AND STARTSWITH(STRING_COL, 'abc')";
testQuery(filterQuery, nonFilterQuery);
filterQuery =
- "SELECT SUM(INT_COL) FILTER(WHERE BOOLEAN_COL AND
STARTSWITH(REVERSE(STRING_COL), 'abc')) FROM " + "MyTable";
+ "SELECT SUM(INT_COL) FILTER(WHERE BOOLEAN_COL AND
STARTSWITH(REVERSE(STRING_COL), 'abc')) sum1 FROM MyTable";
+ nonFilterQuery =
+ "SELECT SUM(INT_COL) sum1 FROM MyTable WHERE BOOLEAN_COL=true AND
STARTSWITH(REVERSE(STRING_COL), " + "'abc')";
+ testQuery(filterQuery, nonFilterQuery);
+ }
+
+ @Test
+ public void testFilterResultColumnNameGroupBy() {
+ String filterQuery =
+ "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 9999) FROM MyTable WHERE
INT_COL < 1000000 GROUP BY BOOLEAN_COL";
+ String nonFilterQuery =
+ "SELECT SUM(INT_COL) \"sum(INT_COL) FILTER(WHERE INT_COL > '9999')\"
FROM MyTable WHERE INT_COL > 9999 AND "
+ + "INT_COL < 1000000 GROUP BY BOOLEAN_COL";
+ testQuery(filterQuery, nonFilterQuery);
+
+ filterQuery =
+ "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 9999 AND INT_COL <
1000000) FROM MyTable GROUP BY BOOLEAN_COL";
+ nonFilterQuery =
+ "SELECT SUM(INT_COL) \"sum(INT_COL) FILTER(WHERE (INT_COL > '9999' AND
INT_COL < '1000000'))\" FROM MyTable "
+ + "WHERE INT_COL > 9999 AND INT_COL < 1000000 GROUP BY
BOOLEAN_COL";
+ testQuery(filterQuery, nonFilterQuery);
+ }
+
+ @Test
+ public void testFilterResultColumnNameNonGroupBy() {
+ String filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 9999)
FROM MyTable WHERE INT_COL < 1000000";
+ String nonFilterQuery =
+ "SELECT SUM(INT_COL) \"sum(INT_COL) FILTER(WHERE INT_COL > '9999')\"
FROM MyTable WHERE INT_COL > 9999 AND "
+ + "INT_COL < 1000000";
+ testQuery(filterQuery, nonFilterQuery);
+
+ filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 9999 AND INT_COL
< 1000000) FROM MyTable";
nonFilterQuery =
- "SELECT SUM(INT_COL) FROM MyTable WHERE BOOLEAN_COL=true AND
STARTSWITH(REVERSE(STRING_COL), " + "'abc')";
+ "SELECT SUM(INT_COL) \"sum(INT_COL) FILTER(WHERE (INT_COL > '9999' AND
INT_COL < '1000000'))\" FROM MyTable "
+ + "WHERE INT_COL > 9999 AND INT_COL < 1000000";
testQuery(filterQuery, nonFilterQuery);
}
@@ -305,9 +338,9 @@ public class FilteredAggregationsTest extends
BaseQueriesTest {
@Test
public void testMultipleAggregationsOnSameFilter() {
- String filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL >
29990), "
- + "MAX(INT_COL) FILTER(WHERE INT_COL > 29990) FROM MyTable";
- String nonFilterQuery = "SELECT MIN(INT_COL), MAX(INT_COL) FROM MyTable
WHERE INT_COL > 29990";
+ String filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL >
29990) testMin, "
+ + "MAX(INT_COL) FILTER(WHERE INT_COL > 29990) testMax FROM MyTable";
+ String nonFilterQuery = "SELECT MIN(INT_COL) testMin, MAX(INT_COL) testMax
FROM MyTable WHERE INT_COL > 29990";
testQuery(filterQuery, nonFilterQuery);
filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 29990) AS
total_min, "
@@ -321,6 +354,26 @@ public class FilteredAggregationsTest extends
BaseQueriesTest {
testQuery(filterQuery, nonFilterQuery);
}
+ @Test
+ public void testMultipleAggregationsOnSameFilterOrderByFiltered() {
+ String filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL >
29990) testMin, "
+ + "MAX(INT_COL) FILTER(WHERE INT_COL > 29990) testMax FROM MyTable
ORDER BY testMax";
+ String nonFilterQuery =
+ "SELECT MIN(INT_COL) testMin, MAX(INT_COL) testMax FROM MyTable WHERE
INT_COL > 29990 ORDER BY testMax";
+ testQuery(filterQuery, nonFilterQuery);
+
+ filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 29990) AS
total_min, "
+ + "MAX(INT_COL) FILTER(WHERE INT_COL > 29990) AS total_max, "
+ + "SUM(INT_COL) FILTER(WHERE NO_INDEX_COL < 5000) AS total_sum, "
+ + "MAX(NO_INDEX_COL) FILTER(WHERE NO_INDEX_COL < 5000) AS total_max2
FROM MyTable ORDER BY total_sum";
+ nonFilterQuery = "SELECT MIN(CASE WHEN (NO_INDEX_COL > 29990) THEN INT_COL
ELSE 99999 END) AS total_min, "
+ + "MAX(CASE WHEN (INT_COL > 29990) THEN INT_COL ELSE 0 END) AS
total_max, "
+ + "SUM(CASE WHEN (NO_INDEX_COL < 5000) THEN INT_COL ELSE 0 END) AS
total_sum, "
+ + "MAX(CASE WHEN (NO_INDEX_COL < 5000) THEN NO_INDEX_COL ELSE 0 END)
AS total_max2 FROM MyTable ORDER BY "
+ + "total_sum";
+ testQuery(filterQuery, nonFilterQuery);
+ }
+
@Test
public void testMixedAggregationsOfSameType() {
String filterQuery = "SELECT SUM(INT_COL), SUM(INT_COL) FILTER(WHERE
INT_COL > 25000) AS total_sum FROM MyTable";
@@ -337,8 +390,8 @@ public class FilteredAggregationsTest extends
BaseQueriesTest {
@Test
public void testGroupBy() {
- String filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 25000)
FROM MyTable GROUP BY BOOLEAN_COL";
- String nonFilterQuery = "SELECT SUM(INT_COL) FROM MyTable WHERE INT_COL >
25000 GROUP BY BOOLEAN_COL";
+ String filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 25000)
testSum FROM MyTable GROUP BY BOOLEAN_COL";
+ String nonFilterQuery = "SELECT SUM(INT_COL) testSum FROM MyTable WHERE
INT_COL > 25000 GROUP BY BOOLEAN_COL";
testQuery(filterQuery, nonFilterQuery);
}
@@ -356,17 +409,19 @@ public class FilteredAggregationsTest extends
BaseQueriesTest {
@Test
public void testGroupBySameFilter() {
String filterQuery =
- "SELECT AVG(INT_COL) FILTER(WHERE INT_COL > 25000), SUM(INT_COL)
FILTER(WHERE INT_COL > 25000) FROM MyTable "
- + "GROUP BY BOOLEAN_COL";
- String nonFilterQuery = "SELECT AVG(INT_COL), SUM(INT_COL) FROM MyTable
WHERE INT_COL > 25000 GROUP BY BOOLEAN_COL";
+ "SELECT AVG(INT_COL) FILTER(WHERE INT_COL > 25000) testAvg,
SUM(INT_COL) FILTER(WHERE INT_COL > 25000) "
+ + "testSum FROM MyTable GROUP BY BOOLEAN_COL";
+ String nonFilterQuery =
+ "SELECT AVG(INT_COL) testAvg, SUM(INT_COL) testSum FROM MyTable WHERE
INT_COL > 25000 GROUP BY BOOLEAN_COL";
testQuery(filterQuery, nonFilterQuery);
}
@Test
public void testMultipleAggregationsOnSameFilterGroupBy() {
- String filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL >
29990), "
- + "MAX(INT_COL) FILTER(WHERE INT_COL > 29990) FROM MyTable GROUP BY
BOOLEAN_COL";
- String nonFilterQuery = "SELECT MIN(INT_COL), MAX(INT_COL) FROM MyTable
WHERE INT_COL > 29990 GROUP BY BOOLEAN_COL";
+ String filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL >
29990) testMin, "
+ + "MAX(INT_COL) FILTER(WHERE INT_COL > 29990) testMax FROM MyTable
GROUP BY BOOLEAN_COL";
+ String nonFilterQuery =
+ "SELECT MIN(INT_COL) testMin, MAX(INT_COL) testMax FROM MyTable WHERE
INT_COL > 29990 GROUP BY BOOLEAN_COL";
testQuery(filterQuery, nonFilterQuery);
filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 29990) AS
total_min, "
@@ -380,4 +435,15 @@ public class FilteredAggregationsTest extends
BaseQueriesTest {
+ "BOOLEAN_COL";
testQuery(filterQuery, nonFilterQuery);
}
+
+ @Test
+ public void testGroupBySameFilterOrderByFiltered() {
+ String filterQuery =
+ "SELECT AVG(INT_COL) FILTER(WHERE INT_COL > 25000) testAvg,
SUM(INT_COL) FILTER(WHERE INT_COL > 25000) "
+ + "testSum FROM MyTable GROUP BY BOOLEAN_COL ORDER BY testAvg";
+ String nonFilterQuery =
+ "SELECT AVG(INT_COL) testAvg, SUM(INT_COL) testSum FROM MyTable WHERE
INT_COL > 25000 GROUP BY BOOLEAN_COL "
+ + "ORDER BY testAvg";
+ testQuery(filterQuery, nonFilterQuery);
+ }
}
diff --git
a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueQueriesTest.java
b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueQueriesTest.java
index 668aecd0ba..760c1c78c1 100644
---
a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueQueriesTest.java
+++
b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueQueriesTest.java
@@ -519,7 +519,8 @@ public class InterSegmentAggregationMultiValueQueriesTest
extends BaseMultiValue
public void testFilteredAggregations() {
String query = "SELECT COUNT(*) FILTER(WHERE column1 > 5) FROM testTable
WHERE column3 > 0";
BrokerResponseNative brokerResponse = getBrokerResponse(query);
- DataSchema expectedDataSchema = new DataSchema(new String[]{"count(*)"},
new ColumnDataType[]{ColumnDataType.LONG});
+ DataSchema expectedDataSchema = new DataSchema(new String[]{"count(*)
FILTER(WHERE column1 > '5')"},
+ new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.LONG});
ResultTable expectedResultTable =
new ResultTable(expectedDataSchema, Collections.singletonList(new
Object[]{370236L}));
QueriesTestUtils.testInterSegmentsResult(brokerResponse, 740472L, 400000L,
0L, 400000L, expectedResultTable);
diff --git
a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueRawQueriesTest.java
b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueRawQueriesTest.java
index 7b4325df6d..06d89e6573 100644
---
a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueRawQueriesTest.java
+++
b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueRawQueriesTest.java
@@ -530,8 +530,8 @@ public class
InterSegmentAggregationMultiValueRawQueriesTest extends BaseMultiVa
public void testFilteredAggregations() {
String query = "SELECT COUNT(*) FILTER(WHERE column1 > 5) FROM testTable
WHERE column3 > 0";
BrokerResponseNative brokerResponse = getBrokerResponse(query);
- DataSchema expectedDataSchema = new DataSchema(new String[]{"count(*)"},
new DataSchema.ColumnDataType[]
- {DataSchema.ColumnDataType.LONG});
+ DataSchema expectedDataSchema = new DataSchema(new String[]{"count(*)
FILTER(WHERE column1 > '5')"},
+ new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.LONG});
ResultTable expectedResultTable =
new ResultTable(expectedDataSchema, Collections.singletonList(new
Object[]{370236L}));
QueriesTestUtils.testInterSegmentsResult(brokerResponse, 740472L, 400000L,
0L, 400000L, expectedResultTable);
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]