This is an automated email from the ASF dual-hosted git repository.

rongr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new b4c8271757 [multistage] fix literal aggregate function not working issue (#9331)
b4c8271757 is described below

commit b4c8271757101e5d1061e08d929d200823a69c95
Author: Rong Rong <[email protected]>
AuthorDate: Fri Sep 2 18:02:09 2022 -0700

    [multistage] fix literal aggregate function not working issue (#9331)
    
    * fix literal aggregate function not working issue for SUM(1) and COUNT(*)
    
    Co-authored-by: Rong Rong <[email protected]>
---
 .../query/runtime/operator/AggregateOperator.java  | 30 ++++++++++++++++++----
 .../pinot/query/runtime/QueryRunnerTest.java       |  2 +-
 2 files changed, 26 insertions(+), 6 deletions(-)

diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
index 2c33b9db4e..67c5edfa99 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
@@ -25,6 +25,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import javax.annotation.Nullable;
+import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.pinot.common.request.context.ExpressionContext;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.core.common.Operator;
@@ -39,6 +40,7 @@ import org.apache.pinot.core.query.selection.SelectionOperatorUtils;
 import org.apache.pinot.query.planner.logical.RexExpression;
 import org.apache.pinot.query.runtime.blocks.TransferableBlock;
 import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
+import org.apache.pinot.spi.data.FieldSpec;
 
 
 /**
@@ -53,6 +55,7 @@ public class AggregateOperator extends BaseOperator<TransferableBlock> {
 
   private final AggregationFunction[] _aggregationFunctions;
   private final int[] _aggregationFunctionInputRefs;
+  private final Object[] _aggregationFunctionLiterals;
   private final DataSchema _resultSchema;
   private final Map<Integer, Object>[] _groupByResultHolders;
   private final Map<Integer, Object[]> _groupByKeyHolder;
@@ -72,10 +75,18 @@ public class AggregateOperator extends BaseOperator<TransferableBlock> {
 
     _aggregationFunctions = new AggregationFunction[_aggCalls.size()];
     _aggregationFunctionInputRefs = new int[_aggCalls.size()];
+    _aggregationFunctionLiterals = new Object[_aggCalls.size()];
     _groupByResultHolders = new Map[_aggCalls.size()];
     _groupByKeyHolder = new HashMap<Integer, Object[]>();
     for (int i = 0; i < aggCalls.size(); i++) {
-      _aggregationFunctionInputRefs[i] = 
toAggregationFunctionRefIndex(aggCalls.get(i));
+      // agg function operand should either be a InputRef or a Literal
+      RexExpression rexExpression = 
toAggregationFunctionOperand(aggCalls.get(i));
+      if (rexExpression instanceof RexExpression.InputRef) {
+        _aggregationFunctionInputRefs[i] = ((RexExpression.InputRef) 
rexExpression).getIndex();
+      } else {
+        _aggregationFunctionInputRefs[i] = -1;
+        _aggregationFunctionLiterals[i] = ((RexExpression.Literal) 
rexExpression).getValue();
+      }
       _aggregationFunctions[i] = toAggregationFunction(aggCalls.get(i), 
_aggregationFunctionInputRefs[i]);
       _groupByResultHolders[i] = new HashMap<Integer, Object>();
     }
@@ -84,10 +95,11 @@ public class AggregateOperator extends BaseOperator<TransferableBlock> {
     _isCumulativeBlockConstructed = false;
   }
 
-  private int toAggregationFunctionRefIndex(RexExpression rexExpression) {
+  private RexExpression toAggregationFunctionOperand(RexExpression 
rexExpression) {
     List<RexExpression> functionOperands = ((RexExpression.FunctionCall) 
rexExpression).getFunctionOperands();
     Preconditions.checkState(functionOperands.size() < 2);
-    return functionOperands.size() == 0 ? 0 : ((RexExpression.InputRef) 
functionOperands.get(0)).getIndex();
+    return functionOperands.size() > 0 ? functionOperands.get(0)
+        : new RexExpression.Literal(FieldSpec.DataType.INT, 
SqlTypeName.INTEGER, 1);
   }
 
   @Override
@@ -155,10 +167,12 @@ public class AggregateOperator extends BaseOperator<TransferableBlock> {
           for (int i = 0; i < _aggregationFunctions.length; i++) {
             Object currentRes = _groupByResultHolders[i].get(keyHashCode);
             if (currentRes == null) {
-              _groupByResultHolders[i].put(keyHashCode, 
row[_aggregationFunctionInputRefs[i]]);
+              _groupByResultHolders[i].put(keyHashCode, 
_aggregationFunctionInputRefs[i] == -1
+                  ? _aggregationFunctionLiterals[i] : 
row[_aggregationFunctionInputRefs[i]]);
             } else {
               _groupByResultHolders[i].put(keyHashCode,
-                  merge(_aggCalls.get(i), currentRes, 
row[_aggregationFunctionInputRefs[i]]));
+                  merge(_aggCalls.get(i), currentRes, 
_aggregationFunctionInputRefs[i] == -1
+                      ? _aggregationFunctionLiterals[i] : 
row[_aggregationFunctionInputRefs[i]]));
             }
           }
         }
@@ -190,6 +204,9 @@ public class AggregateOperator extends BaseOperator<TransferableBlock> {
       case "MAX":
         return new MaxAggregationFunction(
             
ExpressionContext.forIdentifier(String.valueOf(aggregationFunctionInputRef)));
+      // COUNT(*) is rewritten to SUM(1)
+      case "COUNT":
+        return new SumAggregationFunction(ExpressionContext.forLiteral("1"));
       default:
         throw new IllegalStateException(
             "Unexpected value: " + ((RexExpression.FunctionCall) 
aggCall).getFunctionName());
@@ -211,6 +228,9 @@ public class AggregateOperator extends BaseOperator<TransferableBlock> {
       case "$MAX":
       case "$MAX0":
         return Math.max(((Number) left).doubleValue(), ((Number) 
right).doubleValue());
+      // COUNT(*) doesn't need to parse right object.
+      case "COUNT":
+        return ((Number) left).doubleValue() + 1;
       default:
         throw new IllegalStateException(
             "Unexpected value: " + ((RexExpression.FunctionCall) 
aggCall).getFunctionName());
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTest.java
index 3d89315868..462c5da0c3 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTest.java
@@ -137,7 +137,7 @@ public class QueryRunnerTest extends QueryRunnerTestBase {
 
         // GROUP BY after JOIN
         // only 3 GROUP BY key exist because b.col2 cycles between "foo", 
"bar", "alice".
-        new Object[]{"SELECT a.col1, SUM(b.col3) FROM a JOIN b ON a.col1 = 
b.col2 "
+        new Object[]{"SELECT a.col1, SUM(b.col3), COUNT(*), SUM(2) FROM a JOIN 
b ON a.col1 = b.col2 "
             + " WHERE a.col3 >= 0 GROUP BY a.col1", 3},
 
         // Sub-query


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to