This is an automated email from the ASF dual-hosted git repository.
xiangfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 55b6024467 Fix literal handling in Window functions (#13428)
55b6024467 is described below
commit 55b6024467e5e4ad8f4a260453054c42eaac706d
Author: Xiang Fu <[email protected]>
AuthorDate: Tue Jun 18 18:10:55 2024 -0700
Fix literal handling in Window functions (#13428)
---
.../tests/OfflineClusterIntegrationTest.java | 45 +++++++++++++++++++
.../rules/PinotWindowExchangeNodeInsertRule.java | 51 +++++++++++++++++++---
2 files changed, 91 insertions(+), 5 deletions(-)
diff --git
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java
index 49bf22b8e8..94145bb486 100644
---
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java
+++
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java
@@ -2113,6 +2113,51 @@ public class OfflineClusterIntegrationTest extends
BaseClusterIntegrationTestSet
assertEquals(row.get(0).asDouble(), 16071.0 / 2);
}
+  /**
+   * Runs a {@code LAG(num_trips, 2)} window query (a window function with a literal argument) with the multi-stage
+   * query engine enabled and checks every returned row against a plain GROUP BY over the same table.
+   */
+  @Test
+  public void testWindowAggregationV2()
+      throws Exception {
+    setUseMultiStageQueryEngine(true);
+    // Reference result: one row per DaysSinceEpoch, in the same order the window is evaluated in
+    String tmpTableQuery =
+        "select DaysSinceEpoch, count(*) as num_trips from mytable GROUP BY DaysSinceEpoch order by DaysSinceEpoch";
+    JsonNode tmpTableResult = postQuery(tmpTableQuery).get("resultTable").get("rows");
+
+    String query = "WITH tmp AS (\n"
+        + " select count(*) as num_trips, DaysSinceEpoch from mytable GROUP BY DaysSinceEpoch\n"
+        + ")\n"
+        + "\n"
+        + "SELECT\n"
+        + " DaysSinceEpoch,\n"
+        + " num_trips,\n"
+        + " LAG(num_trips, 2) OVER (ORDER BY DaysSinceEpoch) AS previous_num_trips,\n"
+        + " num_trips - LAG(num_trips, 2) OVER (ORDER BY DaysSinceEpoch) AS difference\n"
+        + "FROM\n"
+        + " tmp";
+    JsonNode response = postQuery(query);
+    JsonNode resultTable = response.get("resultTable");
+    assertEquals(resultTable.get("dataSchema").get("columnDataTypes").toString(),
+        "[\"INT\",\"LONG\",\"LONG\",\"LONG\"]");
+    JsonNode rows = resultTable.get("rows");
+    assertEquals(rows.size(), 364);
+    // The first two rows have no row two positions earlier, so both LAG-derived columns must be null
+    for (int i = 0; i < 2; i++) {
+      JsonNode row = rows.get(i);
+      JsonNode tmpTableRow = tmpTableResult.get(i);
+      assertEquals(row.size(), 4);
+      assertEquals(row.get(0).asInt(), tmpTableRow.get(0).asInt());
+      assertEquals(row.get(1).asLong(), tmpTableRow.get(1).asLong());
+      assertTrue(row.get(2).isNull());
+      assertTrue(row.get(3).isNull());
+    }
+    // Every remaining row, including the last one, must carry the num_trips value from two rows earlier
+    for (int i = 2; i < rows.size(); i++) {
+      JsonNode row = rows.get(i);
+      assertEquals(row.size(), 4);
+      JsonNode tmpTableRow = tmpTableResult.get(i);
+      assertEquals(row.get(0).asInt(), tmpTableRow.get(0).asInt());
+      assertEquals(row.get(1).asLong(), tmpTableRow.get(1).asLong());
+      assertEquals(rows.get(i - 2).get(1).asLong(), row.get(2).asLong());
+      assertEquals(row.get(1).asLong() - row.get(2).asLong(), row.get(3).asLong());
+    }
+  }
+
@Test(dataProvider = "useBothQueryEngines")
public void testSelectionUDF(boolean useMultiStageQueryEngine)
throws Exception {
diff --git
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotWindowExchangeNodeInsertRule.java
b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotWindowExchangeNodeInsertRule.java
index e9caf1216f..c004aba293 100644
---
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotWindowExchangeNodeInsertRule.java
+++
b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotWindowExchangeNodeInsertRule.java
@@ -41,6 +41,7 @@ import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilderFactory;
@@ -92,7 +93,7 @@ public class PinotWindowExchangeNodeInsertRule extends
RelOptRule {
// Perform all validations
validateWindows(window);
- Window.Group windowGroup = window.groups.get(0);
+ Window.Group windowGroup = updateLiteralArgumentsInWindowGroup(window);
if (windowGroup.keys.isEmpty() &&
windowGroup.orderKeys.getKeys().isEmpty()) {
// Empty OVER()
// Add a single Exchange for empty OVER() since no sort is required
@@ -111,7 +112,8 @@ public class PinotWindowExchangeNodeInsertRule extends
RelOptRule {
PinotLogicalExchange exchange = PinotLogicalExchange.create(windowInput,
RelDistributions.hash(Collections.emptyList()));
call.transformTo(
- LogicalWindow.create(window.getTraitSet(), exchange,
window.constants, window.getRowType(), window.groups));
+ LogicalWindow.create(window.getTraitSet(), exchange,
window.constants, window.getRowType(),
+ List.of(windowGroup)));
} else if (windowGroup.keys.isEmpty() &&
!windowGroup.orderKeys.getKeys().isEmpty()) {
// Only ORDER BY
// Add a LogicalSortExchange with collation on the order by key(s) and
an empty hash partition key
@@ -121,7 +123,7 @@ public class PinotWindowExchangeNodeInsertRule extends
RelOptRule {
PinotLogicalSortExchange sortExchange =
PinotLogicalSortExchange.create(windowInput,
RelDistributions.hash(Collections.emptyList()),
windowGroup.orderKeys, false, true);
call.transformTo(LogicalWindow.create(window.getTraitSet(),
sortExchange, window.constants, window.getRowType(),
- window.groups));
+ List.of(windowGroup)));
} else {
// All other variants
// Assess whether this is a PARTITION BY only query or not (includes
queries of the type where PARTITION BY and
@@ -134,7 +136,7 @@ public class PinotWindowExchangeNodeInsertRule extends
RelOptRule {
PinotLogicalExchange exchange =
PinotLogicalExchange.create(windowInput,
RelDistributions.hash(windowGroup.keys.toList()));
call.transformTo(LogicalWindow.create(window.getTraitSet(), exchange,
window.constants, window.getRowType(),
- window.groups));
+ List.of(windowGroup)));
} else {
// PARTITION BY and ORDER BY on different key(s)
// Add a LogicalSortExchange hashed on the partition by keys and
collation based on order by keys
@@ -145,11 +147,50 @@ public class PinotWindowExchangeNodeInsertRule extends
RelOptRule {
PinotLogicalSortExchange sortExchange =
PinotLogicalSortExchange.create(windowInput,
RelDistributions.hash(windowGroup.keys.toList()),
windowGroup.orderKeys, false, true);
call.transformTo(LogicalWindow.create(window.getTraitSet(),
sortExchange, window.constants, window.getRowType(),
- window.groups));
+ List.of(windowGroup)));
}
}
}
+  /**
+   * Returns the first window group with every aggregate-call operand that points into {@code window.constants}
+   * replaced by the constant itself.
+   *
+   * <p>An operand is treated as a constant reference when it is a {@link RexInputRef} whose index is at or beyond
+   * the field count of the window's input row type; the constant is looked up at that index minus the field
+   * count. The original group instance is returned untouched when no operand was replaced.
+   *
+   * @param window window rel whose first group is inspected
+   * @return the first group itself, or a new group carrying the rewritten aggregate calls
+   */
+  private Window.Group updateLiteralArgumentsInWindowGroup(Window window) {
+    Window.Group group = window.groups.get(0);
+    int numInputFields = window.getInput().getRowType().getFieldCount();
+    List<Window.RexWinAggCall> resolvedAggCalls = new ArrayList<>(group.aggCalls.size());
+    boolean anyCallRewritten = false;
+    for (Window.RexWinAggCall aggCall : group.aggCalls) {
+      List<RexNode> operands = aggCall.getOperands();
+      List<RexNode> resolvedOperands = new ArrayList<>(operands.size());
+      boolean operandReplaced = false;
+      for (RexNode operand : operands) {
+        // Offset into window.constants; stays negative for anything that is not a constant reference
+        int constantIndex = -1;
+        if (operand instanceof RexInputRef) {
+          constantIndex = ((RexInputRef) operand).getIndex() - numInputFields;
+        }
+        if (constantIndex < 0) {
+          resolvedOperands.add(operand);
+        } else {
+          resolvedOperands.add(window.constants.get(constantIndex));
+          operandReplaced = true;
+        }
+      }
+      if (!operandReplaced) {
+        // Nothing to resolve: keep the existing call instance
+        resolvedAggCalls.add(aggCall);
+        continue;
+      }
+      anyCallRewritten = true;
+      resolvedAggCalls.add(
+          new Window.RexWinAggCall((SqlAggFunction) aggCall.getOperator(), aggCall.type, resolvedOperands,
+              aggCall.ordinal, aggCall.distinct, aggCall.ignoreNulls));
+    }
+    if (!anyCallRewritten) {
+      return group;
+    }
+    return new Window.Group(group.keys, group.isRows, group.lowerBound, group.upperBound, group.orderKeys,
+        resolvedAggCalls);
+  }
+
private void validateWindows(Window window) {
int numGroups = window.groups.size();
// For Phase 1 we only handle single window groups
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]