This is an automated email from the ASF dual-hosted git repository.
jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 8042408 Fix filtered aggregation when it is mixed with regular
aggregation (#8172)
8042408 is described below
commit 80424086e701a0962656d1bfec2b9965cc36d22d
Author: Xiaotian (Jackie) Jiang <[email protected]>
AuthorDate: Tue Feb 15 11:50:53 2022 -0800
Fix filtered aggregation when it is mixed with regular aggregation (#8172)
When a query mixes a regular aggregation and a filtered aggregation of the
same function (e.g. `SUM(col), SUM(col) FILTER(WHERE ...)`), the index of the
aggregation function is not maintained correctly.
This PR contains the following changes:
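- Fix `QueryContext` to assign each (aggregation, filter) pair its own index,
so regular and filtered aggregations of the same function no longer collide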
- Fix `PostAggregationHandler` to always use the filtered aggregation index
map, looking up regular aggregations with a `null` filter
- Add tests for `QueryContext` and queries
---
.../core/query/reduce/PostAggregationHandler.java | 13 +-
.../core/query/request/context/QueryContext.java | 110 +++++++---------
.../BrokerRequestToQueryContextConverterTest.java | 146 +++++++++++++++++++--
.../pinot/queries/FilteredAggregationsTest.java | 14 ++
4 files changed, 204 insertions(+), 79 deletions(-)
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/PostAggregationHandler.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/PostAggregationHandler.java
index 85c2565..d1a3cbf 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/PostAggregationHandler.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/PostAggregationHandler.java
@@ -39,7 +39,6 @@ import org.apache.pinot.core.util.GapfillUtils;
* aggregation result.
*/
public class PostAggregationHandler {
- private final Map<FunctionContext, Integer> _aggregationFunctionIndexMap;
private final Map<Pair<FunctionContext, FilterContext>, Integer>
_filteredAggregationsIndexMap;
private final int _numGroupByExpressions;
private final Map<ExpressionContext, Integer> _groupByExpressionIndexMap;
@@ -48,9 +47,8 @@ public class PostAggregationHandler {
private final DataSchema _resultDataSchema;
public PostAggregationHandler(QueryContext queryContext, DataSchema
dataSchema) {
- _aggregationFunctionIndexMap =
queryContext.getAggregationFunctionIndexMap();
_filteredAggregationsIndexMap =
queryContext.getFilteredAggregationsIndexMap();
- assert _aggregationFunctionIndexMap != null;
+ assert _filteredAggregationsIndexMap != null;
List<ExpressionContext> groupByExpressions =
queryContext.getGroupByExpressions();
if (groupByExpressions != null) {
_numGroupByExpressions = groupByExpressions.size();
@@ -121,14 +119,15 @@ public class PostAggregationHandler {
expression);
if (function.getType() == FunctionContext.Type.AGGREGATION) {
// Aggregation function
- return new
ColumnValueExtractor(_aggregationFunctionIndexMap.get(function) +
_numGroupByExpressions);
+ return new ColumnValueExtractor(
+ _filteredAggregationsIndexMap.get(Pair.of(function, null)) +
_numGroupByExpressions);
} else if (function.getType() == FunctionContext.Type.TRANSFORM &&
function.getFunctionName()
.equalsIgnoreCase("filter")) {
+ FunctionContext aggregation =
function.getArguments().get(0).getFunction();
ExpressionContext filterExpression = function.getArguments().get(1);
FilterContext filter = RequestContextUtils.getFilter(filterExpression);
- FunctionContext filterFunction =
function.getArguments().get(0).getFunction();
-
- return new
ColumnValueExtractor(_filteredAggregationsIndexMap.get(Pair.of(filterFunction,
filter)));
+ return new ColumnValueExtractor(
+ _filteredAggregationsIndexMap.get(Pair.of(aggregation, filter)) +
_numGroupByExpressions);
} else {
// Post-aggregation function
return new PostAggregationValueExtractor(function);
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java
index e3a08a2..3a396b6 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java
@@ -457,82 +457,68 @@ public class QueryContext {
* Helper method to generate the aggregation functions for the query.
*/
private void generateAggregationFunctions(QueryContext queryContext) {
- List<AggregationFunction> aggregationFunctions = new ArrayList<>();
- List<Pair<AggregationFunction, FilterContext>> filteredAggregations =
new ArrayList<>();
- Map<FunctionContext, Integer> aggregationFunctionIndexMap = new
HashMap<>();
- Map<Pair<FunctionContext, FilterContext>, Integer>
filterExpressionIndexMap = new HashMap<>();
+ List<Pair<AggregationFunction, FilterContext>>
filteredAggregationFunctions = new ArrayList<>();
+ Map<Pair<FunctionContext, FilterContext>, Integer>
filteredAggregationsIndexMap = new HashMap<>();
// Add aggregation functions in the SELECT clause
// NOTE: DO NOT deduplicate the aggregation functions in the SELECT
clause because that involves protocol change.
- List<Pair<FilterContext, FunctionContext>> aggregationsInSelect = new
ArrayList<>();
+ List<Pair<FunctionContext, FilterContext>> filteredAggregations = new
ArrayList<>();
for (ExpressionContext selectExpression :
queryContext._selectExpressions) {
- getAggregations(selectExpression, aggregationsInSelect);
+ getAggregations(selectExpression, filteredAggregations);
}
- for (Pair<FilterContext, FunctionContext> pair : aggregationsInSelect) {
- FunctionContext function = pair.getRight();
- int functionIndex = filteredAggregations.size();
- AggregationFunction aggregationFunction =
- AggregationFunctionFactory.getAggregationFunction(function,
queryContext);
-
- FilterContext filterContext = null;
- // If the left pair is not null, implies a filtered aggregation
- if (pair.getLeft() != null) {
+ for (Pair<FunctionContext, FilterContext> pair : filteredAggregations) {
+ FunctionContext aggregation = pair.getLeft();
+ FilterContext filter = pair.getRight();
+ if (filter != null) {
+ // Filtered aggregation
if (_groupByExpressions != null) {
throw new IllegalStateException("GROUP BY with FILTER clauses is
not supported");
}
queryContext._hasFilteredAggregations = true;
- filterContext = pair.getLeft();
- Pair<FunctionContext, FilterContext> filterContextPair =
Pair.of(function, filterContext);
- if (!filterExpressionIndexMap.containsKey(filterContextPair)) {
- int filterMapIndex = filterExpressionIndexMap.size();
- filterExpressionIndexMap.put(filterContextPair, filterMapIndex);
- }
}
- filteredAggregations.add(Pair.of(aggregationFunction, filterContext));
- aggregationFunctionIndexMap.put(function, functionIndex);
+ int functionIndex = filteredAggregationFunctions.size();
+ AggregationFunction aggregationFunction =
+ AggregationFunctionFactory.getAggregationFunction(aggregation,
queryContext);
+ filteredAggregationFunctions.add(Pair.of(aggregationFunction, filter));
+ filteredAggregationsIndexMap.put(Pair.of(aggregation, filter),
functionIndex);
}
- // Add aggregation functions in the HAVING clause but not in the SELECT
clause
+ // Add aggregation functions in the HAVING and ORDER-BY clause but not
in the SELECT clause
+ filteredAggregations.clear();
if (queryContext._havingFilter != null) {
- List<Pair<FilterContext, FunctionContext>> aggregationsInHaving = new
ArrayList<>();
- getAggregations(queryContext._havingFilter, aggregationsInHaving);
- for (Pair<FilterContext, FunctionContext> pair : aggregationsInHaving)
{
- FunctionContext function = pair.getRight();
- if (!aggregationFunctionIndexMap.containsKey(function)) {
- int functionIndex = filteredAggregations.size();
- filteredAggregations.add(
-
Pair.of(AggregationFunctionFactory.getAggregationFunction(function,
queryContext), null));
- aggregationFunctionIndexMap.put(function, functionIndex);
- }
- }
+ getAggregations(queryContext._havingFilter, filteredAggregations);
}
-
- // Add aggregation functions in the ORDER-BY clause but not in the
SELECT or HAVING clause
if (queryContext._orderByExpressions != null) {
- List<Pair<FilterContext, FunctionContext>> aggregationsInOrderBy = new
ArrayList<>();
for (OrderByExpressionContext orderByExpression :
queryContext._orderByExpressions) {
- getAggregations(orderByExpression.getExpression(),
aggregationsInOrderBy);
+ getAggregations(orderByExpression.getExpression(),
filteredAggregations);
}
- for (Pair<FilterContext, FunctionContext> pair :
aggregationsInOrderBy) {
- FunctionContext function = pair.getRight();
- if (!aggregationFunctionIndexMap.containsKey(function)) {
- int functionIndex = filteredAggregations.size();
- filteredAggregations.add(
-
Pair.of(AggregationFunctionFactory.getAggregationFunction(function,
queryContext), null));
- aggregationFunctionIndexMap.put(function, functionIndex);
- }
+ }
+ for (Pair<FunctionContext, FilterContext> pair : filteredAggregations) {
+ if (!filteredAggregationsIndexMap.containsKey(pair)) {
+ FunctionContext aggregation = pair.getLeft();
+ FilterContext filter = pair.getRight();
+ int functionIndex = filteredAggregationFunctions.size();
+ AggregationFunction aggregationFunction =
+ AggregationFunctionFactory.getAggregationFunction(aggregation,
queryContext);
+ filteredAggregationFunctions.add(Pair.of(aggregationFunction,
filter));
+ filteredAggregationsIndexMap.put(Pair.of(aggregation, filter),
functionIndex);
}
}
- if (!filteredAggregations.isEmpty()) {
- for (Pair<AggregationFunction, FilterContext> pair :
filteredAggregations) {
- aggregationFunctions.add(pair.getLeft());
+ if (!filteredAggregationFunctions.isEmpty()) {
+ int numAggregations = filteredAggregationFunctions.size();
+ AggregationFunction[] aggregationFunctions = new
AggregationFunction[numAggregations];
+ for (int i = 0; i < numAggregations; i++) {
+ aggregationFunctions[i] =
filteredAggregationFunctions.get(i).getLeft();
}
-
- queryContext._aggregationFunctions = aggregationFunctions.toArray(new
AggregationFunction[0]);
- queryContext._filteredAggregationFunctions = filteredAggregations;
+ Map<FunctionContext, Integer> aggregationFunctionIndexMap = new
HashMap<>();
+ for (Map.Entry<Pair<FunctionContext, FilterContext>, Integer> entry :
filteredAggregationsIndexMap.entrySet()) {
+ aggregationFunctionIndexMap.put(entry.getKey().getLeft(),
entry.getValue());
+ }
+ queryContext._aggregationFunctions = aggregationFunctions;
+ queryContext._filteredAggregationFunctions =
filteredAggregationFunctions;
queryContext._aggregationFunctionIndexMap =
aggregationFunctionIndexMap;
- queryContext._filteredAggregationsIndexMap = filterExpressionIndexMap;
+ queryContext._filteredAggregationsIndexMap =
filteredAggregationsIndexMap;
}
}
@@ -540,14 +526,14 @@ public class QueryContext {
* Helper method to extract AGGREGATION FunctionContexts and FILTER
FilterContexts from the given expression.
*/
private static void getAggregations(ExpressionContext expression,
- List<Pair<FilterContext, FunctionContext>> aggregations) {
+ List<Pair<FunctionContext, FilterContext>> filteredAggregations) {
FunctionContext function = expression.getFunction();
if (function == null) {
return;
}
if (function.getType() == FunctionContext.Type.AGGREGATION) {
// Aggregation
- aggregations.add(Pair.of(null, function));
+ filteredAggregations.add(Pair.of(function, null));
} else {
List<ExpressionContext> arguments = function.getArguments();
if (function.getFunctionName().equalsIgnoreCase("filter")) {
@@ -561,12 +547,11 @@ public class QueryContext {
&& filterExpression.getFunction().getType() ==
FunctionContext.Type.TRANSFORM,
"Second argument of FILTER must be a filter expression");
FilterContext filter =
RequestContextUtils.getFilter(filterExpression);
-
- aggregations.add(Pair.of(filter, aggregation));
+ filteredAggregations.add(Pair.of(aggregation, filter));
} else {
// Transform
for (ExpressionContext argument : arguments) {
- getAggregations(argument, aggregations);
+ getAggregations(argument, filteredAggregations);
}
}
}
@@ -575,14 +560,15 @@ public class QueryContext {
/**
* Helper method to extract AGGREGATION FunctionContexts and FILTER
FilterContexts from the given filter.
*/
- private static void getAggregations(FilterContext filter,
List<Pair<FilterContext, FunctionContext>> aggregations) {
+ private static void getAggregations(FilterContext filter,
+ List<Pair<FunctionContext, FilterContext>> filteredAggregations) {
List<FilterContext> children = filter.getChildren();
if (children != null) {
for (FilterContext child : children) {
- getAggregations(child, aggregations);
+ getAggregations(child, filteredAggregations);
}
} else {
- getAggregations(filter.getPredicate().getLhs(), aggregations);
+ getAggregations(filter.getPredicate().getLhs(), filteredAggregations);
}
}
diff --git
a/pinot-core/src/test/java/org/apache/pinot/core/query/request/context/utils/BrokerRequestToQueryContextConverterTest.java
b/pinot-core/src/test/java/org/apache/pinot/core/query/request/context/utils/BrokerRequestToQueryContextConverterTest.java
index 1f690c1..794d705 100644
---
a/pinot-core/src/test/java/org/apache/pinot/core/query/request/context/utils/BrokerRequestToQueryContextConverterTest.java
+++
b/pinot-core/src/test/java/org/apache/pinot/core/query/request/context/utils/BrokerRequestToQueryContextConverterTest.java
@@ -39,6 +39,8 @@ import
org.apache.pinot.common.request.context.predicate.RangePredicate;
import org.apache.pinot.common.request.context.predicate.TextMatchPredicate;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import
org.apache.pinot.core.query.aggregation.function.CountAggregationFunction;
+import org.apache.pinot.core.query.aggregation.function.MinAggregationFunction;
+import org.apache.pinot.core.query.aggregation.function.SumAggregationFunction;
import org.apache.pinot.core.query.request.context.QueryContext;
import org.apache.pinot.pql.parsers.Pql2Compiler;
import org.testng.annotations.Test;
@@ -579,16 +581,140 @@ public class BrokerRequestToQueryContextConverterTest {
@Test
public void testFilteredAggregations() {
- String query = "SELECT COUNT(*) FILTER(WHERE foo > 5), COUNT(*)
FILTER(WHERE foo < 6) FROM testTable WHERE bar > 0";
- QueryContext queryContext =
QueryContextConverterUtils.getQueryContextFromSQL(query);
- List<Pair<AggregationFunction, FilterContext>> filteredAggregationList =
- queryContext.getFilteredAggregationFunctions();
- assertNotNull(filteredAggregationList);
- assertEquals(filteredAggregationList.size(), 2);
- assertTrue(filteredAggregationList.get(0).getLeft() instanceof
CountAggregationFunction);
- assertEquals(filteredAggregationList.get(0).getRight().toString(), "foo >
'5'");
- assertTrue(filteredAggregationList.get(1).getLeft() instanceof
CountAggregationFunction);
- assertEquals(filteredAggregationList.get(1).getRight().toString(), "foo <
'6'");
+ {
+ String query =
+ "SELECT COUNT(*) FILTER(WHERE foo > 5), COUNT(*) FILTER(WHERE foo <
6) FROM testTable WHERE bar > 0";
+ QueryContext queryContext =
QueryContextConverterUtils.getQueryContextFromSQL(query);
+
+ AggregationFunction[] aggregationFunctions =
queryContext.getAggregationFunctions();
+ assertNotNull(aggregationFunctions);
+ assertEquals(aggregationFunctions.length, 2);
+ assertTrue(aggregationFunctions[0] instanceof CountAggregationFunction);
+ assertTrue(aggregationFunctions[1] instanceof CountAggregationFunction);
+
+ List<Pair<AggregationFunction, FilterContext>>
filteredAggregationFunctions =
+ queryContext.getFilteredAggregationFunctions();
+ assertNotNull(filteredAggregationFunctions);
+ assertEquals(filteredAggregationFunctions.size(), 2);
+ assertTrue(filteredAggregationFunctions.get(0).getLeft() instanceof
CountAggregationFunction);
+ assertEquals(filteredAggregationFunctions.get(0).getRight().toString(),
"foo > '5'");
+ assertTrue(filteredAggregationFunctions.get(1).getLeft() instanceof
CountAggregationFunction);
+ assertEquals(filteredAggregationFunctions.get(1).getRight().toString(),
"foo < '6'");
+
+ Map<FunctionContext, Integer> aggregationIndexMap =
queryContext.getAggregationFunctionIndexMap();
+ assertNotNull(aggregationIndexMap);
+ assertEquals(aggregationIndexMap.size(), 1);
+ for (Map.Entry<FunctionContext, Integer> entry :
aggregationIndexMap.entrySet()) {
+ FunctionContext aggregation = entry.getKey();
+ int index = entry.getValue();
+ assertEquals(aggregation.toString(), "count(*)");
+ assertTrue(index == 0 || index == 1);
+ }
+
+ Map<Pair<FunctionContext, FilterContext>, Integer>
filteredAggregationsIndexMap =
+ queryContext.getFilteredAggregationsIndexMap();
+ assertNotNull(filteredAggregationsIndexMap);
+ assertEquals(filteredAggregationsIndexMap.size(), 2);
+ for (Map.Entry<Pair<FunctionContext, FilterContext>, Integer> entry :
filteredAggregationsIndexMap.entrySet()) {
+ Pair<FunctionContext, FilterContext> pair = entry.getKey();
+ FunctionContext aggregation = pair.getLeft();
+ FilterContext filter = pair.getRight();
+ int index = entry.getValue();
+ assertEquals(aggregation.toString(), "count(*)");
+ switch (index) {
+ case 0:
+ assertEquals(filter.toString(), "foo > '5'");
+ break;
+ case 1:
+ assertEquals(filter.toString(), "foo < '6'");
+ break;
+ default:
+ fail();
+ break;
+ }
+ }
+ }
+
+ {
+ String query =
+ "SELECT SUM(salary), SUM(salary) FILTER(WHERE salary IS NOT NULL),
MIN(salary), MIN(salary) FILTER(WHERE "
+ + "salary > 50000) FROM testTable WHERE bar > 0";
+ QueryContext queryContext =
QueryContextConverterUtils.getQueryContextFromSQL(query);
+
+ AggregationFunction[] aggregationFunctions =
queryContext.getAggregationFunctions();
+ assertNotNull(aggregationFunctions);
+ assertEquals(aggregationFunctions.length, 4);
+ assertTrue(aggregationFunctions[0] instanceof SumAggregationFunction);
+ assertTrue(aggregationFunctions[1] instanceof SumAggregationFunction);
+ assertTrue(aggregationFunctions[2] instanceof MinAggregationFunction);
+ assertTrue(aggregationFunctions[3] instanceof MinAggregationFunction);
+
+ List<Pair<AggregationFunction, FilterContext>>
filteredAggregationFunctions =
+ queryContext.getFilteredAggregationFunctions();
+ assertNotNull(filteredAggregationFunctions);
+ assertEquals(filteredAggregationFunctions.size(), 4);
+ assertTrue(filteredAggregationFunctions.get(0).getLeft() instanceof
SumAggregationFunction);
+ assertNull(filteredAggregationFunctions.get(0).getRight());
+ assertTrue(filteredAggregationFunctions.get(1).getLeft() instanceof
SumAggregationFunction);
+ assertEquals(filteredAggregationFunctions.get(1).getRight().toString(),
"salary IS NOT NULL");
+ assertTrue(filteredAggregationFunctions.get(2).getLeft() instanceof
MinAggregationFunction);
+ assertNull(filteredAggregationFunctions.get(2).getRight());
+ assertTrue(filteredAggregationFunctions.get(3).getLeft() instanceof
MinAggregationFunction);
+ assertEquals(filteredAggregationFunctions.get(3).getRight().toString(),
"salary > '50000'");
+
+ Map<FunctionContext, Integer> aggregationIndexMap =
queryContext.getAggregationFunctionIndexMap();
+ assertNotNull(aggregationIndexMap);
+ assertEquals(aggregationIndexMap.size(), 2);
+ for (Map.Entry<FunctionContext, Integer> entry :
aggregationIndexMap.entrySet()) {
+ FunctionContext aggregation = entry.getKey();
+ int index = entry.getValue();
+ switch (index) {
+ case 0:
+ case 1:
+ assertEquals(aggregation.toString(), "sum(salary)");
+ break;
+ case 2:
+ case 3:
+ assertEquals(aggregation.toString(), "min(salary)");
+ break;
+ default:
+ fail();
+ break;
+ }
+ }
+
+ Map<Pair<FunctionContext, FilterContext>, Integer>
filteredAggregationsIndexMap =
+ queryContext.getFilteredAggregationsIndexMap();
+ assertNotNull(filteredAggregationsIndexMap);
+ assertEquals(filteredAggregationsIndexMap.size(), 4);
+ for (Map.Entry<Pair<FunctionContext, FilterContext>, Integer> entry :
filteredAggregationsIndexMap.entrySet()) {
+ Pair<FunctionContext, FilterContext> pair = entry.getKey();
+ FunctionContext aggregation = pair.getLeft();
+ FilterContext filter = pair.getRight();
+ int index = entry.getValue();
+ switch (index) {
+ case 0:
+ assertEquals(aggregation.toString(), "sum(salary)");
+ assertNull(filter);
+ break;
+ case 1:
+ assertEquals(aggregation.toString(), "sum(salary)");
+ assertEquals(filter.toString(), "salary IS NOT NULL");
+ break;
+ case 2:
+ assertEquals(aggregation.toString(), "min(salary)");
+ assertNull(filter);
+ break;
+ case 3:
+ assertEquals(aggregation.toString(), "min(salary)");
+ assertEquals(filter.toString(), "salary > '50000'");
+ break;
+ default:
+ fail();
+ break;
+ }
+ }
+ }
}
@Test
diff --git
a/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java
b/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java
index 4e101fe..763a731 100644
---
a/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java
+++
b/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java
@@ -299,6 +299,20 @@ public class FilteredAggregationsTest extends
BaseQueriesTest {
testQuery(filterQuery, nonFilterQuery);
}
+ @Test
+ public void testMixedAggregationsOfSameType() {
+ String filterQuery = "SELECT SUM(INT_COL), SUM(INT_COL) FILTER(WHERE
INT_COL > 25000) AS total_sum FROM MyTable";
+ String nonFilterQuery =
+ "SELECT SUM(INT_COL), SUM(CASE WHEN INT_COL > 25000 THEN INT_COL ELSE
0 END) AS total_sum FROM MyTable";
+ testQuery(filterQuery, nonFilterQuery);
+
+ filterQuery = "SELECT SUM(INT_COL), SUM(INT_COL) FILTER(WHERE INT_COL <
5000) AS total_sum, "
+ + "SUM(INT_COL) FILTER(WHERE INT_COL > 12345) AS total_sum2 FROM
MyTable";
+ nonFilterQuery = "SELECT SUM(INT_COL), SUM(CASE WHEN INT_COL < 5000 THEN
INT_COL ELSE 0 END) AS total_sum, "
+ + "SUM(CASE WHEN INT_COL > 12345 THEN INT_COL ELSE 0 END) AS
total_sum2 FROM MyTable";
+ testQuery(filterQuery, nonFilterQuery);
+ }
+
@Test(expectedExceptions = IllegalStateException.class)
public void testGroupBySupport() {
String filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 2),
MAX(INT_COL) FILTER(WHERE INT_COL > 2) "
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]