This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 147e05d7a8 Fix null handling for window aggregate (#13611)
147e05d7a8 is described below

commit 147e05d7a8c3593fd418a84ecac10aee99b72a90
Author: Xiaotian (Jackie) Jiang <[email protected]>
AuthorDate: Mon Jul 15 15:09:48 2024 -0700

    Fix null handling for window aggregate (#13611)
---
 .../runtime/operator/utils/AggregationUtils.java   | 32 ++++++++--------------
 .../window/aggregate/AggregateWindowFunction.java  | 11 ++++++--
 .../src/test/resources/queries/NullHandling.json   |  4 +++
 3 files changed, 25 insertions(+), 22 deletions(-)

diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java
index ed24af5a3c..ee2441fe70 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java
@@ -18,7 +18,6 @@
  */
 package org.apache.pinot.query.runtime.operator.utils;
 
-import com.google.common.collect.ImmutableMap;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -170,25 +169,18 @@ public class AggregationUtils {
    */
   public static class Accumulator {
     //@formatter:off
-    public static final Map<String, Function<DataSchema.ColumnDataType, AggregationUtils.Merger>> MERGERS =
-        ImmutableMap.<String, Function<DataSchema.ColumnDataType, AggregationUtils.Merger>>builder()
-            .put("SUM", cdt -> AggregationUtils::mergeSum)
-            .put("$SUM", cdt -> AggregationUtils::mergeSum)
-            .put("$SUM0", cdt -> AggregationUtils::mergeSum)
-            .put("MIN", cdt -> AggregationUtils::mergeMin)
-            .put("$MIN", cdt -> AggregationUtils::mergeMin)
-            .put("$MIN0", cdt -> AggregationUtils::mergeMin)
-            .put("MAX", cdt -> AggregationUtils::mergeMax)
-            .put("$MAX", cdt -> AggregationUtils::mergeMax)
-            .put("$MAX0", cdt -> AggregationUtils::mergeMax)
-            .put("COUNT", cdt -> new AggregationUtils.MergeCounts())
-            .put("BOOL_AND", cdt -> AggregationUtils::mergeBoolAnd)
-            .put("$BOOL_AND", cdt -> AggregationUtils::mergeBoolAnd)
-            .put("$BOOL_AND0", cdt -> AggregationUtils::mergeBoolAnd)
-            .put("BOOL_OR", cdt -> AggregationUtils::mergeBoolOr)
-            .put("$BOOL_OR", cdt -> AggregationUtils::mergeBoolOr)
-            .put("$BOOL_OR0", cdt -> AggregationUtils::mergeBoolOr)
-            .build();
+    public static final Map<String, Function<DataSchema.ColumnDataType, AggregationUtils.Merger>> MERGERS = Map.of(
+        "SUM", cdt -> AggregationUtils::mergeSum,
+        // NOTE: Keep both 'SUM0' and '$SUM0' for backward compatibility where 'SUM0' is SqlKind and '$SUM0' is function
+        //       name.
+        "SUM0", cdt -> AggregationUtils::mergeSum,
+        "$SUM0", cdt -> AggregationUtils::mergeSum,
+        "MIN", cdt -> AggregationUtils::mergeMin,
+        "MAX", cdt -> AggregationUtils::mergeMax,
+        "COUNT", cdt -> new AggregationUtils.MergeCounts(),
+        "BOOL_AND", cdt -> AggregationUtils::mergeBoolAnd,
+        "BOOL_OR", cdt -> AggregationUtils::mergeBoolOr
+    );
     //@formatter:on
 
     protected final int _inputRef;
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/aggregate/AggregateWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/aggregate/AggregateWindowFunction.java
index fbf0afed77..6763542bd0 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/aggregate/AggregateWindowFunction.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/aggregate/AggregateWindowFunction.java
@@ -18,26 +18,33 @@
  */
 package org.apache.pinot.query.runtime.operator.window.aggregate;
 
+import com.google.common.base.Preconditions;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.function.Function;
 import org.apache.calcite.rel.RelFieldCollation;
 import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
 import org.apache.pinot.core.data.table.Key;
 import org.apache.pinot.query.planner.logical.RexExpression;
 import org.apache.pinot.query.runtime.operator.utils.AggregationUtils;
+import org.apache.pinot.query.runtime.operator.utils.AggregationUtils.Merger;
 import org.apache.pinot.query.runtime.operator.window.WindowFunction;
 
 
 public class AggregateWindowFunction extends WindowFunction {
-  private final AggregationUtils.Merger _merger;
+  private final Merger _merger;
 
   public AggregateWindowFunction(RexExpression.FunctionCall aggCall, DataSchema inputSchema,
       List<RelFieldCollation> collations, boolean partitionByOnly) {
     super(aggCall, inputSchema, collations, partitionByOnly);
-    _merger = AggregationUtils.Accumulator.MERGERS.get(aggCall.getFunctionName()).apply(_dataType);
+    String functionName = aggCall.getFunctionName();
+    Function<ColumnDataType, Merger> mergerCreator = AggregationUtils.Accumulator.MERGERS.get(functionName);
+    Preconditions.checkArgument(mergerCreator != null, "Unsupported aggregate function: %s", functionName);
+    _merger = mergerCreator.apply(_dataType);
   }
 
   @Override
diff --git a/pinot-query-runtime/src/test/resources/queries/NullHandling.json b/pinot-query-runtime/src/test/resources/queries/NullHandling.json
index ee15c88f34..95e6b6c0c5 100644
--- a/pinot-query-runtime/src/test/resources/queries/NullHandling.json
+++ b/pinot-query-runtime/src/test/resources/queries/NullHandling.json
@@ -309,6 +309,10 @@
         "sql": "SET enableNullHandling=true; SELECT strCol1, intCol1, nIntCol1, nnIntCol1, strCol2, nStrCol2, nnStrCol2 FROM {tbl1} WHERE nStrCol2 IS NULL AND nIntCol1 IS NOT NULL",
         "h2Sql": "SELECT strCol1, intCol1, nIntCol1, nnIntCol1, strCol2, nStrCol2, 'null' FROM {tbl1} WHERE nStrCol2 IS NULL AND nIntCol1 IS NOT NULL"
       },
+      {
+        "description": "window function with NULL handling",
+        "sql": "SET enableNullHandling=true; SELECT SUM(intCol1) OVER() FROM {tbl1}"
+      },
 
       {
         "description": "Leaf stages should not return nulls",


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to