This is an automated email from the ASF dual-hosted git repository.

xiangfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 55b6024467 Fix literal handling in Window functions (#13428)
55b6024467 is described below

commit 55b6024467e5e4ad8f4a260453054c42eaac706d
Author: Xiang Fu <[email protected]>
AuthorDate: Tue Jun 18 18:10:55 2024 -0700

    Fix literal handling in Window functions (#13428)
---
 .../tests/OfflineClusterIntegrationTest.java       | 45 +++++++++++++++++++
 .../rules/PinotWindowExchangeNodeInsertRule.java   | 51 +++++++++++++++++++---
 2 files changed, 91 insertions(+), 5 deletions(-)

diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java
index 49bf22b8e8..94145bb486 100644
--- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java
+++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java
@@ -2113,6 +2113,51 @@ public class OfflineClusterIntegrationTest extends BaseClusterIntegrationTestSet
     assertEquals(row.get(0).asDouble(), 16071.0 / 2);
   }
 
+  @Test
+  public void testWindowAggregationV2()
+      throws Exception {
+    setUseMultiStageQueryEngine(true);  // both queries below run on the multi-stage engine
+    String tmpTableQuery =
+        "select DaysSinceEpoch, count(*) as num_trips from mytable GROUP BY DaysSinceEpoch order by DaysSinceEpoch";
+    JsonNode tmpTableResult = postQuery(tmpTableQuery).get("resultTable").get("rows");
+    // Same aggregation wrapped in a CTE, plus LAG(num_trips, 2) and the difference to it; checked against the rows above
+    String query = "WITH tmp AS (\n"
+        + "  select count(*) as num_trips, DaysSinceEpoch  from mytable GROUP BY DaysSinceEpoch\n"
+        + ")\n"
+        + "\n"
+        + "SELECT\n"
+        + "    DaysSinceEpoch,\n"
+        + "    num_trips,\n"
+        + "    LAG(num_trips, 2) OVER (ORDER BY DaysSinceEpoch) AS previous_num_trips,\n"
+        + "    num_trips - LAG(num_trips, 2) OVER (ORDER BY DaysSinceEpoch) AS difference\n"
+        + "FROM\n"
+        + "    tmp";
+    JsonNode response = postQuery(query);
+    JsonNode resultTable = response.get("resultTable");
+    assertEquals(resultTable.get("dataSchema").get("columnDataTypes").toString(),
+        "[\"INT\",\"LONG\",\"LONG\",\"LONG\"]");
+    JsonNode rows = resultTable.get("rows");
+    assertEquals(rows.size(), 364);
+    for (int i = 0; i < 2; i++) {  // first two rows: LAG(num_trips, 2) has no row two positions back
+      JsonNode row = rows.get(i);
+      JsonNode tmpTableRow = tmpTableResult.get(i);
+      assertEquals(row.size(), 4);
+      assertEquals(row.get(0).asInt(), tmpTableRow.get(0).asInt());
+      assertEquals(row.get(1).asLong(), tmpTableRow.get(1).asLong());
+      assertTrue(row.get(2).isNull());  // previous_num_trips
+      assertTrue(row.get(3).isNull());  // difference, derived from the null LAG value
+    }
+    for (int i = 2; i < rows.size(); i++) {  // every remaining row, including the last one
+      JsonNode row = rows.get(i);
+      assertEquals(row.size(), 4);
+      JsonNode tmpTableRow = tmpTableResult.get(i);
+      assertEquals(row.get(0).asInt(), tmpTableRow.get(0).asInt());
+      assertEquals(row.get(1).asLong(), tmpTableRow.get(1).asLong());
+      assertEquals(rows.get(i - 2).get(1).asLong(), row.get(2).asLong());
+      assertEquals(row.get(1).asLong() - row.get(2).asLong(), row.get(3).asLong());
+    }
+  }
+
   @Test(dataProvider = "useBothQueryEngines")
   public void testSelectionUDF(boolean useMultiStageQueryEngine)
       throws Exception {
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotWindowExchangeNodeInsertRule.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotWindowExchangeNodeInsertRule.java
index e9caf1216f..c004aba293 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotWindowExchangeNodeInsertRule.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotWindowExchangeNodeInsertRule.java
@@ -41,6 +41,7 @@ import org.apache.calcite.rel.type.RelDataTypeField;
 import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.rex.RexInputRef;
 import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlAggFunction;
 import org.apache.calcite.sql.SqlKind;
 import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.calcite.tools.RelBuilderFactory;
@@ -92,7 +93,7 @@ public class PinotWindowExchangeNodeInsertRule extends RelOptRule {
     // Perform all validations
     validateWindows(window);
 
-    Window.Group windowGroup = window.groups.get(0);
+    Window.Group windowGroup = updateLiteralArgumentsInWindowGroup(window);
     if (windowGroup.keys.isEmpty() && 
windowGroup.orderKeys.getKeys().isEmpty()) {
       // Empty OVER()
       // Add a single Exchange for empty OVER() since no sort is required
@@ -111,7 +112,8 @@ public class PinotWindowExchangeNodeInsertRule extends RelOptRule {
       PinotLogicalExchange exchange = PinotLogicalExchange.create(windowInput,
           RelDistributions.hash(Collections.emptyList()));
       call.transformTo(
-          LogicalWindow.create(window.getTraitSet(), exchange, 
window.constants, window.getRowType(), window.groups));
+          LogicalWindow.create(window.getTraitSet(), exchange, 
window.constants, window.getRowType(),
+              List.of(windowGroup)));
     } else if (windowGroup.keys.isEmpty() && 
!windowGroup.orderKeys.getKeys().isEmpty()) {
       // Only ORDER BY
       // Add a LogicalSortExchange with collation on the order by key(s) and 
an empty hash partition key
@@ -121,7 +123,7 @@ public class PinotWindowExchangeNodeInsertRule extends RelOptRule {
       PinotLogicalSortExchange sortExchange = 
PinotLogicalSortExchange.create(windowInput,
           RelDistributions.hash(Collections.emptyList()), 
windowGroup.orderKeys, false, true);
       call.transformTo(LogicalWindow.create(window.getTraitSet(), 
sortExchange, window.constants, window.getRowType(),
-          window.groups));
+          List.of(windowGroup)));
     } else {
       // All other variants
       // Assess whether this is a PARTITION BY only query or not (includes 
queries of the type where PARTITION BY and
@@ -134,7 +136,7 @@ public class PinotWindowExchangeNodeInsertRule extends RelOptRule {
         PinotLogicalExchange exchange = 
PinotLogicalExchange.create(windowInput,
             RelDistributions.hash(windowGroup.keys.toList()));
         call.transformTo(LogicalWindow.create(window.getTraitSet(), exchange, 
window.constants, window.getRowType(),
-            window.groups));
+            List.of(windowGroup)));
       } else {
         // PARTITION BY and ORDER BY on different key(s)
         // Add a LogicalSortExchange hashed on the partition by keys and 
collation based on order by keys
@@ -145,11 +147,50 @@ public class PinotWindowExchangeNodeInsertRule extends RelOptRule {
         PinotLogicalSortExchange sortExchange = 
PinotLogicalSortExchange.create(windowInput,
             RelDistributions.hash(windowGroup.keys.toList()), 
windowGroup.orderKeys, false, true);
         call.transformTo(LogicalWindow.create(window.getTraitSet(), 
sortExchange, window.constants, window.getRowType(),
-            window.groups));
+            List.of(windowGroup)));
       }
     }
   }
 
+  private Window.Group updateLiteralArgumentsInWindowGroup(Window window) {
+    Window.Group oldWindowGroup = window.groups.get(0);  // only the first group is inspected
+    int windowInputSize = window.getInput().getRowType().getFieldCount();
+    ImmutableList<Window.RexWinAggCall> oldAggCalls = oldWindowGroup.aggCalls;
+    List<Window.RexWinAggCall> newAggCallWindow = new ArrayList<>(oldAggCalls.size());
+    boolean aggCallChanged = false;
+    for (Window.RexWinAggCall oldAggCall : oldAggCalls) {
+      boolean changed = false;
+      List<RexNode> oldAggCallArgList = oldAggCall.getOperands();
+      List<RexNode> rexList = new ArrayList<>(oldAggCallArgList.size());
+      for (RexNode rexNode : oldAggCallArgList) {
+        RexNode newRexNode = rexNode;
+        if (rexNode instanceof RexInputRef) {
+          RexInputRef inputRef = (RexInputRef) rexNode;
+          int inputRefIndex = inputRef.getIndex();
+          // An input ref index at or beyond the window input field count refers to an entry in window.constants
+          if (inputRefIndex >= windowInputSize) {
+            newRexNode = window.constants.get(inputRefIndex - windowInputSize);
+            changed = true;
+            aggCallChanged = true;
+          }
+        }
+        rexList.add(newRexNode);
+      }
+      if (changed) {  // rebuild the call around the substituted operands; all other attributes are copied
+        newAggCallWindow.add(
+            new Window.RexWinAggCall((SqlAggFunction) oldAggCall.getOperator(), oldAggCall.type, rexList,
+                oldAggCall.ordinal, oldAggCall.distinct, oldAggCall.ignoreNulls));
+      } else {
+        newAggCallWindow.add(oldAggCall);
+      }
+    }
+    if (aggCallChanged) {
+      return new Window.Group(oldWindowGroup.keys, oldWindowGroup.isRows, oldWindowGroup.lowerBound,
+          oldWindowGroup.upperBound, oldWindowGroup.orderKeys, newAggCallWindow);
+    }
+    return oldWindowGroup;  // no operand referenced a constant, so the original group is reused
+  }
+
   private void validateWindows(Window window) {
     int numGroups = window.groups.size();
     // For Phase 1 we only handle single window groups


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to