This is an automated email from the ASF dual-hosted git repository.
rongr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new b4c8271757 [multistage] fix literal aggregate function not working
issue (#9331)
b4c8271757 is described below
commit b4c8271757101e5d1061e08d929d200823a69c95
Author: Rong Rong <[email protected]>
AuthorDate: Fri Sep 2 18:02:09 2022 -0700
[multistage] fix literal aggregate function not working issue (#9331)
* fix literal aggregate function not working issue for SUM(1) and COUNT(*)
Co-authored-by: Rong Rong <[email protected]>
---
.../query/runtime/operator/AggregateOperator.java | 30 ++++++++++++++++++----
.../pinot/query/runtime/QueryRunnerTest.java | 2 +-
2 files changed, 26 insertions(+), 6 deletions(-)
diff --git
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
index 2c33b9db4e..67c5edfa99 100644
---
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
+++
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
@@ -25,6 +25,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
+import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.common.Operator;
@@ -39,6 +40,7 @@ import
org.apache.pinot.core.query.selection.SelectionOperatorUtils;
import org.apache.pinot.query.planner.logical.RexExpression;
import org.apache.pinot.query.runtime.blocks.TransferableBlock;
import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
+import org.apache.pinot.spi.data.FieldSpec;
/**
@@ -53,6 +55,7 @@ public class AggregateOperator extends
BaseOperator<TransferableBlock> {
private final AggregationFunction[] _aggregationFunctions;
private final int[] _aggregationFunctionInputRefs;
+ private final Object[] _aggregationFunctionLiterals;
private final DataSchema _resultSchema;
private final Map<Integer, Object>[] _groupByResultHolders;
private final Map<Integer, Object[]> _groupByKeyHolder;
@@ -72,10 +75,18 @@ public class AggregateOperator extends
BaseOperator<TransferableBlock> {
_aggregationFunctions = new AggregationFunction[_aggCalls.size()];
_aggregationFunctionInputRefs = new int[_aggCalls.size()];
+ _aggregationFunctionLiterals = new Object[_aggCalls.size()];
_groupByResultHolders = new Map[_aggCalls.size()];
_groupByKeyHolder = new HashMap<Integer, Object[]>();
for (int i = 0; i < aggCalls.size(); i++) {
- _aggregationFunctionInputRefs[i] =
toAggregationFunctionRefIndex(aggCalls.get(i));
+ // agg function operand should be either an InputRef or a Literal
+ RexExpression rexExpression =
toAggregationFunctionOperand(aggCalls.get(i));
+ if (rexExpression instanceof RexExpression.InputRef) {
+ _aggregationFunctionInputRefs[i] = ((RexExpression.InputRef)
rexExpression).getIndex();
+ } else {
+ _aggregationFunctionInputRefs[i] = -1;
+ _aggregationFunctionLiterals[i] = ((RexExpression.Literal)
rexExpression).getValue();
+ }
_aggregationFunctions[i] = toAggregationFunction(aggCalls.get(i),
_aggregationFunctionInputRefs[i]);
_groupByResultHolders[i] = new HashMap<Integer, Object>();
}
@@ -84,10 +95,11 @@ public class AggregateOperator extends
BaseOperator<TransferableBlock> {
_isCumulativeBlockConstructed = false;
}
- private int toAggregationFunctionRefIndex(RexExpression rexExpression) {
+ private RexExpression toAggregationFunctionOperand(RexExpression
rexExpression) {
List<RexExpression> functionOperands = ((RexExpression.FunctionCall)
rexExpression).getFunctionOperands();
Preconditions.checkState(functionOperands.size() < 2);
- return functionOperands.size() == 0 ? 0 : ((RexExpression.InputRef)
functionOperands.get(0)).getIndex();
+ return functionOperands.size() > 0 ? functionOperands.get(0)
+ : new RexExpression.Literal(FieldSpec.DataType.INT,
SqlTypeName.INTEGER, 1);
}
@Override
@@ -155,10 +167,12 @@ public class AggregateOperator extends
BaseOperator<TransferableBlock> {
for (int i = 0; i < _aggregationFunctions.length; i++) {
Object currentRes = _groupByResultHolders[i].get(keyHashCode);
if (currentRes == null) {
- _groupByResultHolders[i].put(keyHashCode,
row[_aggregationFunctionInputRefs[i]]);
+ _groupByResultHolders[i].put(keyHashCode,
_aggregationFunctionInputRefs[i] == -1
+ ? _aggregationFunctionLiterals[i] :
row[_aggregationFunctionInputRefs[i]]);
} else {
_groupByResultHolders[i].put(keyHashCode,
- merge(_aggCalls.get(i), currentRes,
row[_aggregationFunctionInputRefs[i]]));
+ merge(_aggCalls.get(i), currentRes,
_aggregationFunctionInputRefs[i] == -1
+ ? _aggregationFunctionLiterals[i] :
row[_aggregationFunctionInputRefs[i]]));
}
}
}
@@ -190,6 +204,9 @@ public class AggregateOperator extends
BaseOperator<TransferableBlock> {
case "MAX":
return new MaxAggregationFunction(
ExpressionContext.forIdentifier(String.valueOf(aggregationFunctionInputRef)));
+ // COUNT(*) is rewritten to SUM(1)
+ case "COUNT":
+ return new SumAggregationFunction(ExpressionContext.forLiteral("1"));
default:
throw new IllegalStateException(
"Unexpected value: " + ((RexExpression.FunctionCall)
aggCall).getFunctionName());
@@ -211,6 +228,9 @@ public class AggregateOperator extends
BaseOperator<TransferableBlock> {
case "$MAX":
case "$MAX0":
return Math.max(((Number) left).doubleValue(), ((Number)
right).doubleValue());
+ // COUNT(*) does not need to read the right operand; each merge adds 1.
+ case "COUNT":
+ return ((Number) left).doubleValue() + 1;
default:
throw new IllegalStateException(
"Unexpected value: " + ((RexExpression.FunctionCall)
aggCall).getFunctionName());
diff --git
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTest.java
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTest.java
index 3d89315868..462c5da0c3 100644
---
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTest.java
+++
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTest.java
@@ -137,7 +137,7 @@ public class QueryRunnerTest extends QueryRunnerTestBase {
// GROUP BY after JOIN
// only 3 GROUP BY keys exist because b.col2 cycles between "foo",
"bar", "alice".
- new Object[]{"SELECT a.col1, SUM(b.col3) FROM a JOIN b ON a.col1 =
b.col2 "
+ new Object[]{"SELECT a.col1, SUM(b.col3), COUNT(*), SUM(2) FROM a JOIN
b ON a.col1 = b.col2 "
+ " WHERE a.col3 >= 0 GROUP BY a.col1", 3},
// Sub-query
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]