This is an automated email from the ASF dual-hosted git repository.
xiangfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 0dcad92828 Fix LEAD/LAG window function implementation (#13340)
0dcad92828 is described below
commit 0dcad928285fbb8c05c7444c0a65eedceed28ab3
Author: Xiang Fu <[email protected]>
AuthorDate: Sat Jun 8 01:26:42 2024 -0700
Fix LEAD/LAG window function implementation (#13340)
---
.../pinot/query/QueryEnvironmentTestBase.java | 4 +
.../runtime/operator/utils/AggregationUtils.java | 4 +-
.../window/value/LagValueWindowFunction.java | 56 +++++++++++---
.../window/value/LeadValueWindowFunction.java | 50 ++++++++++--
.../operator/WindowAggregateOperatorTest.java | 90 ++++++++++++++++++++++
5 files changed, 182 insertions(+), 22 deletions(-)
diff --git
a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java
b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java
index 6b3b8a3631..8e33ec1ef3 100644
---
a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java
+++
b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java
@@ -191,6 +191,10 @@ public class QueryEnvironmentTestBase {
new Object[]{"SELECT RANK() OVER(PARTITION BY a.col2 ORDER BY a.col1)
FROM a"},
new Object[]{"SELECT a.col1, LEAD(a.col3) OVER (PARTITION BY a.col2
ORDER BY a.col3) FROM a"},
new Object[]{"SELECT a.col1, LAG(a.col3) OVER (PARTITION BY a.col2
ORDER BY a.col3) FROM a"},
+ new Object[]{"SELECT a.col1, LEAD(a.col3, 5) OVER (PARTITION BY a.col2
ORDER BY a.col3) FROM a"},
+ new Object[]{"SELECT a.col1, LAG(a.col3, 5) OVER (PARTITION BY a.col2
ORDER BY a.col3) FROM a"},
+ new Object[]{"SELECT a.col1, LEAD(a.col3, 5, -1) OVER (PARTITION BY
a.col2 ORDER BY a.col3) FROM a"},
+ new Object[]{"SELECT a.col1, LAG(a.col3, 5, -1) OVER (PARTITION BY
a.col2 ORDER BY a.col3) FROM a"},
new Object[]{"SELECT DENSE_RANK() OVER(ORDER BY a.col1) FROM a"},
new Object[]{"SELECT a.col1, SUM(a.col3) OVER (ORDER BY a.col2),
MIN(a.col3) OVER (ORDER BY a.col2) FROM a"},
new Object[]{
diff --git
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java
index 0133843be0..ed24af5a3c 100644
---
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java
+++
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java
@@ -18,7 +18,6 @@
*/
package org.apache.pinot.query.runtime.operator.utils;
-import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import java.util.HashMap;
import java.util.List;
@@ -223,8 +222,7 @@ public class AggregationUtils {
+ /**
+ * Returns the operand that the aggregation operates on: the first function operand, or the literal 1 when the
+ * call has no operands.
+ */
private RexExpression toAggregationFunctionOperand(RexExpression.FunctionCall aggCall) {
List<RexExpression> functionOperands = aggCall.getFunctionOperands();
int numOperands = functionOperands.size();
- Preconditions.checkState(numOperands < 2, "Aggregate functions cannot have more than one operand");
- return numOperands == 1 ? functionOperands.get(0) : new RexExpression.Literal(ColumnDataType.INT, 1);
+ // Only the first operand is the aggregated value. Additional operands (e.g. the offset and default value of
+ // LEAD/LAG) are interpreted by the individual function implementations, so they are not rejected here.
+ if (numOperands > 0) {
+ return functionOperands.get(0);
+ }
+ // No operands (presumably COUNT(*)-style calls): aggregate over the constant literal 1.
+ return new RexExpression.Literal(ColumnDataType.INT, 1);
}
}
}
diff --git
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LagValueWindowFunction.java
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LagValueWindowFunction.java
index 63c49bac44..797fca9313 100644
---
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LagValueWindowFunction.java
+++
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LagValueWindowFunction.java
@@ -18,32 +18,66 @@
*/
package org.apache.pinot.query.runtime.operator.window.value;
-import java.util.ArrayList;
+import com.google.common.base.Preconditions;
+import java.util.Arrays;
import java.util.List;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.common.utils.PinotDataType;
import org.apache.pinot.query.planner.logical.RexExpression;
+ /**
+ * LAG window function: for each row of a partition, returns the value extracted from the row {@code offset}
+ * positions earlier, or the default value (null if none was given) when no such row exists.
+ * Accepted operand forms: {@code LAG(expr)}, {@code LAG(expr, offset)}, {@code LAG(expr, offset, default)}.
+ */
public class LagValueWindowFunction extends ValueWindowFunction {
+ // Number of rows to look back; 1 unless a numeric offset literal is supplied.
+ private final int _offset;
+ // Value emitted for rows that have no row 'offset' positions earlier; null unless a default literal is supplied.
+ private final Object _defaultValue;
+
public LagValueWindowFunction(RexExpression.FunctionCall aggCall, DataSchema inputSchema,
List<RelFieldCollation> collations, boolean partitionByOnly) {
super(aggCall, inputSchema, collations, partitionByOnly);
+ int offset = 1;
+ Object defaultValue = null;
+ List<RexExpression> operands = aggCall.getFunctionOperands();
+ int numOperands = operands.size();
+ if (numOperands > 1) {
+ RexExpression secondOperand = operands.get(1);
+ Preconditions.checkArgument(secondOperand instanceof RexExpression.Literal,
+ "Second operand (offset) of LAG function must be a literal");
+ Object offsetValue = ((RexExpression.Literal) secondOperand).getValue();
+ // A non-numeric (e.g. null) offset literal leaves the offset at its default of 1.
+ if (offsetValue instanceof Number) {
+ offset = ((Number) offsetValue).intValue();
+ }
+ }
+ // Negative offsets are not supported; fail fast with a clear message instead of an array index exception later.
+ Preconditions.checkArgument(offset >= 0, "Offset of LAG function must be non-negative, got: %s", offset);
+ if (numOperands == 3) {
+ RexExpression thirdOperand = operands.get(2);
+ Preconditions.checkArgument(thirdOperand instanceof RexExpression.Literal,
+ "Third operand (default value) of LAG function must be a literal");
+ RexExpression.Literal defaultValueLiteral = (RexExpression.Literal) thirdOperand;
+ defaultValue = defaultValueLiteral.getValue();
+ if (defaultValue != null) {
+ DataSchema.ColumnDataType srcDataType = defaultValueLiteral.getDataType();
+ // NOTE(review): the target type comes from input column 0, not from the column referenced by the first
+ // operand; confirm these always coincide for LAG inputs.
+ DataSchema.ColumnDataType destDataType = inputSchema.getColumnDataType(0);
+ if (srcDataType != destDataType) {
+ // Convert the default value to the same data type as the input column
+ // (e.g. convert INT to LONG, FLOAT to DOUBLE, etc.)
+ defaultValue = PinotDataType.getPinotDataTypeForExecution(destDataType)
+ .convert(defaultValue, PinotDataType.getPinotDataTypeForExecution(srcDataType));
+ }
+ }
+ }
+ _offset = offset;
+ _defaultValue = defaultValue;
}
+ /**
+ * Computes LAG for one ordered partition: row i receives the value of row (i - offset), and the leading rows
+ * without such a row receive the default value (or null).
+ */
@Override
public List<Object> processRows(List<Object[]> rows) {
- List<Object> result = new ArrayList<>(rows.size());
- Object[] prevRow = null;
- for (Object[] row : rows) {
- if (prevRow == null) {
- result.add(null);
- } else {
- result.add(extractValueFromRow(prevRow));
- }
- prevRow = row;
+ int numRows = rows.size();
+ Object[] result = new Object[numRows];
+ // Clamp to the partition size: when the offset exceeds the number of rows, every row takes the default value,
+ // and Arrays.fill must not be given a toIndex past the end of the array.
+ int numRowsWithoutPrevious = Math.min(_offset, numRows);
+ if (_defaultValue != null) {
+ Arrays.fill(result, 0, numRowsWithoutPrevious, _defaultValue);
+ }
+ for (int i = _offset; i < numRows; i++) {
+ result[i] = extractValueFromRow(rows.get(i - _offset));
}
- return result;
+ return Arrays.asList(result);
}
}
diff --git
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LeadValueWindowFunction.java
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LeadValueWindowFunction.java
index 530675844c..099c3fba5f 100644
---
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LeadValueWindowFunction.java
+++
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LeadValueWindowFunction.java
@@ -18,32 +18,66 @@
*/
package org.apache.pinot.query.runtime.operator.window.value;
+import com.google.common.base.Preconditions;
import java.util.Arrays;
import java.util.List;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.common.utils.PinotDataType;
import org.apache.pinot.query.planner.logical.RexExpression;
public class LeadValueWindowFunction extends ValueWindowFunction {
+ // Number of rows to look ahead; 1 unless a numeric offset literal is supplied.
+ private final int _offset;
+ // Value emitted for rows that have no row 'offset' positions later; null unless a default literal is supplied.
+ private final Object _defaultValue;
+
+ /**
+ * Parses the optional offset (second operand) and default value (third operand) of
+ * {@code LEAD(expr[, offset[, default]])}. Both must be literals.
+ */
public LeadValueWindowFunction(RexExpression.FunctionCall aggCall, DataSchema inputSchema,
List<RelFieldCollation> collations, boolean partitionByOnly) {
super(aggCall, inputSchema, collations, partitionByOnly);
+ int offset = 1;
+ Object defaultValue = null;
+ List<RexExpression> operands = aggCall.getFunctionOperands();
+ int numOperands = operands.size();
+ if (numOperands > 1) {
+ RexExpression secondOperand = operands.get(1);
+ Preconditions.checkArgument(secondOperand instanceof RexExpression.Literal,
+ "Second operand (offset) of LEAD function must be a literal");
+ Object offsetValue = ((RexExpression.Literal) secondOperand).getValue();
+ // A non-numeric (e.g. null) offset literal leaves the offset at its default of 1.
+ if (offsetValue instanceof Number) {
+ offset = ((Number) offsetValue).intValue();
+ }
+ }
+ // Negative offsets are not supported; fail fast with a clear message instead of an array index exception later.
+ Preconditions.checkArgument(offset >= 0, "Offset of LEAD function must be non-negative, got: %s", offset);
+ if (numOperands == 3) {
+ RexExpression thirdOperand = operands.get(2);
+ Preconditions.checkArgument(thirdOperand instanceof RexExpression.Literal,
+ "Third operand (default value) of LEAD function must be a literal");
+ RexExpression.Literal defaultValueLiteral = (RexExpression.Literal) thirdOperand;
+ defaultValue = defaultValueLiteral.getValue();
+ if (defaultValue != null) {
+ DataSchema.ColumnDataType srcDataType = defaultValueLiteral.getDataType();
+ // NOTE(review): the target type comes from input column 0, not from the column referenced by the first
+ // operand; confirm these always coincide for LEAD inputs.
+ DataSchema.ColumnDataType destDataType = inputSchema.getColumnDataType(0);
+ if (srcDataType != destDataType) {
+ // Convert the default value to the same data type as the input column
+ // (e.g. convert INT to LONG, FLOAT to DOUBLE, etc.)
+ defaultValue = PinotDataType.getPinotDataTypeForExecution(destDataType)
+ .convert(defaultValue, PinotDataType.getPinotDataTypeForExecution(srcDataType));
+ }
+ }
+ }
+ _offset = offset;
+ _defaultValue = defaultValue;
}
+ /**
+ * Computes LEAD for one ordered partition: row i receives the value of row (i + offset), and the trailing rows
+ * without such a row receive the default value (or null).
+ */
@Override
public List<Object> processRows(List<Object[]> rows) {
int numRows = rows.size();
Object[] result = new Object[numRows];
- Object[] nextRow = null;
- for (int i = numRows - 1; i >= 0; i--) {
- if (nextRow == null) {
- result[i] = null;
- } else {
- result[i] = extractValueFromRow(nextRow);
- }
- nextRow = rows.get(i);
+ // Rows whose index + offset is still inside the partition take the value from that later row.
+ for (int i = 0; i < numRows - _offset; i++) {
+ result[i] = extractValueFromRow(rows.get(i + _offset));
+ }
+ // Clamp the start index: when the offset exceeds the number of rows, every row takes the default value, and
+ // Arrays.fill must not be given a negative fromIndex.
+ if (_defaultValue != null) {
+ Arrays.fill(result, Math.max(0, numRows - _offset), numRows, _defaultValue);
}
return Arrays.asList(result);
}
diff --git
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java
index f351c94f25..cd3ca26b6c 100644
---
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java
+++
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java
@@ -561,6 +561,96 @@ public class WindowAggregateOperatorTest {
"Max rows in window should be reached");
}
+ @Test
+ public void testLeadLagWindowFunction() {
+ // Given: LEAD(group, 1) and LAG(group, 1) with no default value, partitioned by group and ordered by arg
+ DataSchema inputSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{INT, STRING});
+ stubLeadLagInput(inputSchema);
+ DataSchema resultSchema = new DataSchema(new String[]{"group", "arg", "lead", "lag"},
+ new ColumnDataType[]{INT, STRING, INT, INT});
+ List<Integer> keys = List.of(0);
+ List<RelFieldCollation> collations =
+ List.of(new RelFieldCollation(1, RelFieldCollation.Direction.ASCENDING, RelFieldCollation.NullDirection.LAST));
+ List<RexExpression.FunctionCall> aggCalls =
+ List.of(new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.LEAD.name(),
+ List.of(new RexExpression.InputRef(0), new RexExpression.Literal(ColumnDataType.INT, 1))),
+ new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.LAG.name(),
+ List.of(new RexExpression.InputRef(0), new RexExpression.Literal(ColumnDataType.INT, 1))));
+ WindowAggregateOperator operator =
+ getOperator(inputSchema, resultSchema, keys, collations, aggCalls, WindowNode.WindowFrameType.RANGE,
+ Integer.MIN_VALUE, 0);
+
+ // When:
+ List<Object[]> resultRows = operator.nextBlock().getContainer();
+ // Then: the last row of each partition has no LEAD value and the first row has no LAG value (null)
+ verifyResultRows(resultRows, keys, Map.of(
+ 1, List.of(
+ new Object[]{1, "foo", 1, null},
+ new Object[]{1, "foo", 1, 1},
+ new Object[]{1, "numb", null, 1}),
+ 2, List.of(
+ new Object[]{2, "bar", 2, null},
+ new Object[]{2, "foo", 2, 2},
+ new Object[]{2, "foo", 2, 2},
+ new Object[]{2, "the", null, 2}),
+ 3, List.of(
+ new Object[]{3, "and", 3, null},
+ new Object[]{3, "true", null, 3})
+ ));
+ assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)");
+ }
+
+ @Test
+ public void testLeadLagWindowFunction2() {
+ // Given: LEAD(group, 2, 100) and LAG(group, 1, 200), partitioned by group and ordered by arg
+ DataSchema inputSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{INT, STRING});
+ stubLeadLagInput(inputSchema);
+ DataSchema resultSchema = new DataSchema(new String[]{"group", "arg", "lead", "lag"},
+ new ColumnDataType[]{INT, STRING, INT, INT});
+ List<Integer> keys = List.of(0);
+ List<RelFieldCollation> collations =
+ List.of(new RelFieldCollation(1, RelFieldCollation.Direction.ASCENDING, RelFieldCollation.NullDirection.LAST));
+ List<RexExpression.FunctionCall> aggCalls =
+ List.of(new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.LEAD.name(),
+ List.of(new RexExpression.InputRef(0), new RexExpression.Literal(ColumnDataType.INT, 2),
+ new RexExpression.Literal(ColumnDataType.INT, 100))),
+ new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.LAG.name(),
+ List.of(new RexExpression.InputRef(0), new RexExpression.Literal(ColumnDataType.INT, 1),
+ new RexExpression.Literal(ColumnDataType.INT, 200))));
+ WindowAggregateOperator operator =
+ getOperator(inputSchema, resultSchema, keys, collations, aggCalls, WindowNode.WindowFrameType.RANGE,
+ Integer.MIN_VALUE, 0);
+
+ // When:
+ List<Object[]> resultRows = operator.nextBlock().getContainer();
+ // Then: rows without a row 'offset' positions ahead/behind receive the explicit default values
+ verifyResultRows(resultRows, keys, Map.of(
+ 1, List.of(
+ new Object[]{1, "foo", 1, 200},
+ new Object[]{1, "foo", 100, 1},
+ new Object[]{1, "numb", 100, 1}),
+ 2, List.of(
+ new Object[]{2, "bar", 2, 200},
+ new Object[]{2, "foo", 2, 2},
+ new Object[]{2, "foo", 100, 2},
+ new Object[]{2, "the", 100, 2}),
+ 3, List.of(
+ new Object[]{3, "and", 100, 200},
+ new Object[]{3, "true", 100, 3})
+ ));
+ assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)");
+ }
+
+ /**
+ * Stubs {@code _input} to return two data blocks holding rows for groups 1, 2 and 3, followed by an EOS block.
+ */
+ private void stubLeadLagInput(DataSchema inputSchema) {
+ when(_input.nextBlock()).thenReturn(
+ OperatorTestUtil.block(inputSchema, new Object[]{3, "and"}, new Object[]{2, "bar"}, new Object[]{2, "foo"},
+ new Object[]{1, "foo"})).thenReturn(
+ OperatorTestUtil.block(inputSchema, new Object[]{1, "foo"}, new Object[]{2, "foo"}, new Object[]{1, "numb"},
+ new Object[]{2, "the"}, new Object[]{3, "true"}))
+ .thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0));
+ }
+
private WindowAggregateOperator getOperator(DataSchema inputSchema,
DataSchema resultSchema, List<Integer> keys,
List<RelFieldCollation> collations, List<RexExpression.FunctionCall>
aggCalls,
WindowNode.WindowFrameType windowFrameType, int lowerBound, int
upperBound, PlanNode.NodeHint nodeHint) {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]