This is an automated email from the ASF dual-hosted git repository.

xiangfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 0dcad92828 Fix LEAD/LAG window function implementation (#13340)
0dcad92828 is described below

commit 0dcad928285fbb8c05c7444c0a65eedceed28ab3
Author: Xiang Fu <[email protected]>
AuthorDate: Sat Jun 8 01:26:42 2024 -0700

    Fix LEAD/LAG window function implementation (#13340)
---
 .../pinot/query/QueryEnvironmentTestBase.java      |  4 +
 .../runtime/operator/utils/AggregationUtils.java   |  4 +-
 .../window/value/LagValueWindowFunction.java       | 56 +++++++++++---
 .../window/value/LeadValueWindowFunction.java      | 50 ++++++++++--
 .../operator/WindowAggregateOperatorTest.java      | 90 ++++++++++++++++++++++
 5 files changed, 182 insertions(+), 22 deletions(-)

diff --git 
a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java
 
b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java
index 6b3b8a3631..8e33ec1ef3 100644
--- 
a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java
+++ 
b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java
@@ -191,6 +191,10 @@ public class QueryEnvironmentTestBase {
         new Object[]{"SELECT RANK() OVER(PARTITION BY a.col2 ORDER BY a.col1) 
FROM a"},
         new Object[]{"SELECT a.col1, LEAD(a.col3) OVER (PARTITION BY a.col2 
ORDER BY a.col3) FROM a"},
         new Object[]{"SELECT a.col1, LAG(a.col3) OVER (PARTITION BY a.col2 
ORDER BY a.col3) FROM a"},
+        new Object[]{"SELECT a.col1, LEAD(a.col3, 5) OVER (PARTITION BY a.col2 
ORDER BY a.col3) FROM a"},
+        new Object[]{"SELECT a.col1, LAG(a.col3, 5) OVER (PARTITION BY a.col2 
ORDER BY a.col3) FROM a"},
+        new Object[]{"SELECT a.col1, LEAD(a.col3, 5, -1) OVER (PARTITION BY 
a.col2 ORDER BY a.col3) FROM a"},
+        new Object[]{"SELECT a.col1, LAG(a.col3, 5, -1) OVER (PARTITION BY 
a.col2 ORDER BY a.col3) FROM a"},
         new Object[]{"SELECT DENSE_RANK() OVER(ORDER BY a.col1) FROM a"},
         new Object[]{"SELECT a.col1, SUM(a.col3) OVER (ORDER BY a.col2), 
MIN(a.col3) OVER (ORDER BY a.col2) FROM a"},
         new Object[]{
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java
index 0133843be0..ed24af5a3c 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java
@@ -18,7 +18,6 @@
  */
 package org.apache.pinot.query.runtime.operator.utils;
 
-import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableMap;
 import java.util.HashMap;
 import java.util.List;
@@ -223,8 +222,7 @@ public class AggregationUtils {
     private RexExpression 
toAggregationFunctionOperand(RexExpression.FunctionCall aggCall) {
       List<RexExpression> functionOperands = aggCall.getFunctionOperands();
       int numOperands = functionOperands.size();
-      Preconditions.checkState(numOperands < 2, "Aggregate functions cannot 
have more than one operand");
-      return numOperands == 1 ? functionOperands.get(0) : new 
RexExpression.Literal(ColumnDataType.INT, 1);
+      return numOperands == 0 ? new RexExpression.Literal(ColumnDataType.INT, 
1) : functionOperands.get(0);
     }
   }
 }
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LagValueWindowFunction.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LagValueWindowFunction.java
index 63c49bac44..797fca9313 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LagValueWindowFunction.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LagValueWindowFunction.java
@@ -18,32 +18,66 @@
  */
 package org.apache.pinot.query.runtime.operator.window.value;
 
-import java.util.ArrayList;
+import com.google.common.base.Preconditions;
+import java.util.Arrays;
 import java.util.List;
 import org.apache.calcite.rel.RelFieldCollation;
 import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.common.utils.PinotDataType;
 import org.apache.pinot.query.planner.logical.RexExpression;
 
 
 public class LagValueWindowFunction extends ValueWindowFunction {
+  private final int _offset;
+  private final Object _defaultValue;
 
   public LagValueWindowFunction(RexExpression.FunctionCall aggCall, DataSchema 
inputSchema,
       List<RelFieldCollation> collations, boolean partitionByOnly) {
     super(aggCall, inputSchema, collations, partitionByOnly);
+    int offset = 1;
+    Object defaultValue = null;
+    List<RexExpression> operands = aggCall.getFunctionOperands();
+    int numOperands = operands.size();
+    if (numOperands > 1) {
+      RexExpression secondOperand = operands.get(1);
+      Preconditions.checkArgument(secondOperand instanceof 
RexExpression.Literal,
+          "Second operand (offset) of LAG function must be a literal");
+      Object offsetValue = ((RexExpression.Literal) secondOperand).getValue();
+      if (offsetValue instanceof Number) {
+        offset = ((Number) offsetValue).intValue();
+      }
+    }
+    if (numOperands == 3) {
+      RexExpression thirdOperand = operands.get(2);
+      Preconditions.checkArgument(thirdOperand instanceof 
RexExpression.Literal,
+          "Third operand (default value) of LAG function must be a literal");
+      RexExpression.Literal defaultValueLiteral = (RexExpression.Literal) 
thirdOperand;
+      defaultValue = defaultValueLiteral.getValue();
+      if (defaultValue != null) {
+        DataSchema.ColumnDataType srcDataType = 
defaultValueLiteral.getDataType();
+        DataSchema.ColumnDataType destDataType = 
inputSchema.getColumnDataType(0);
+        if (srcDataType != destDataType) {
+          // Convert the default value to the same data type as the input 
column
+          // (e.g. convert INT to LONG, FLOAT to DOUBLE, etc.)
+          defaultValue = 
PinotDataType.getPinotDataTypeForExecution(destDataType)
+              .convert(defaultValue, 
PinotDataType.getPinotDataTypeForExecution(srcDataType));
+        }
+      }
+    }
+    _offset = offset;
+    _defaultValue = defaultValue;
   }
 
   @Override
   public List<Object> processRows(List<Object[]> rows) {
-    List<Object> result = new ArrayList<>(rows.size());
-    Object[] prevRow = null;
-    for (Object[] row : rows) {
-      if (prevRow == null) {
-        result.add(null);
-      } else {
-        result.add(extractValueFromRow(prevRow));
-      }
-      prevRow = row;
+    int numRows = rows.size();
+    Object[] result = new Object[numRows];
+    if (_defaultValue != null) {
+      Arrays.fill(result, 0, Math.min(_offset, numRows), _defaultValue); // offset may exceed numRows
+    }
+    for (int i = _offset; i < numRows; i++) {
+      result[i] = extractValueFromRow(rows.get(i - _offset));
     }
-    return result;
+    return Arrays.asList(result);
   }
 }
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LeadValueWindowFunction.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LeadValueWindowFunction.java
index 530675844c..099c3fba5f 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LeadValueWindowFunction.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LeadValueWindowFunction.java
@@ -18,32 +18,66 @@
  */
 package org.apache.pinot.query.runtime.operator.window.value;
 
+import com.google.common.base.Preconditions;
 import java.util.Arrays;
 import java.util.List;
 import org.apache.calcite.rel.RelFieldCollation;
 import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.common.utils.PinotDataType;
 import org.apache.pinot.query.planner.logical.RexExpression;
 
 
 public class LeadValueWindowFunction extends ValueWindowFunction {
 
+  private final int _offset;
+  private final Object _defaultValue;
+
   public LeadValueWindowFunction(RexExpression.FunctionCall aggCall, 
DataSchema inputSchema,
       List<RelFieldCollation> collations, boolean partitionByOnly) {
     super(aggCall, inputSchema, collations, partitionByOnly);
+    int offset = 1;
+    Object defaultValue = null;
+    List<RexExpression> operands = aggCall.getFunctionOperands();
+    int numOperands = operands.size();
+    if (numOperands > 1) {
+      RexExpression secondOperand = operands.get(1);
+      Preconditions.checkArgument(secondOperand instanceof 
RexExpression.Literal,
+          "Second operand (offset) of LEAD function must be a literal");
+      Object offsetValue = ((RexExpression.Literal) secondOperand).getValue();
+      if (offsetValue instanceof Number) {
+        offset = ((Number) offsetValue).intValue();
+      }
+    }
+    if (numOperands == 3) {
+      RexExpression thirdOperand = operands.get(2);
+      Preconditions.checkArgument(thirdOperand instanceof 
RexExpression.Literal,
+          "Third operand (default value) of LEAD function must be a literal");
+      RexExpression.Literal defaultValueLiteral = (RexExpression.Literal) 
thirdOperand;
+      defaultValue = defaultValueLiteral.getValue();
+      if (defaultValue != null) {
+        DataSchema.ColumnDataType srcDataType = 
defaultValueLiteral.getDataType();
+        DataSchema.ColumnDataType destDataType = 
inputSchema.getColumnDataType(0);
+        if (srcDataType != destDataType) {
+          // Convert the default value to the same data type as the input 
column
+          // (e.g. convert INT to LONG, FLOAT to DOUBLE, etc.)
+          defaultValue = 
PinotDataType.getPinotDataTypeForExecution(destDataType)
+              .convert(defaultValue, 
PinotDataType.getPinotDataTypeForExecution(srcDataType));
+        }
+      }
+    }
+    _offset = offset;
+    _defaultValue = defaultValue;
   }
 
   @Override
   public List<Object> processRows(List<Object[]> rows) {
     int numRows = rows.size();
     Object[] result = new Object[numRows];
-    Object[] nextRow = null;
-    for (int i = numRows - 1; i >= 0; i--) {
-      if (nextRow == null) {
-        result[i] = null;
-      } else {
-        result[i] = extractValueFromRow(nextRow);
-      }
-      nextRow = rows.get(i);
+    for (int i = 0; i < numRows - _offset; i++) {
+      result[i] = extractValueFromRow(rows.get(i + _offset));
+    }
+    if (_defaultValue != null) {
+      Arrays.fill(result, Math.max(0, numRows - _offset), numRows, _defaultValue); // offset may exceed numRows
     }
     return Arrays.asList(result);
   }
diff --git 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java
 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java
index f351c94f25..cd3ca26b6c 100644
--- 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java
+++ 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java
@@ -561,6 +561,96 @@ public class WindowAggregateOperatorTest {
         "Max rows in window should be reached");
   }
 
+  @Test
+  public void testLeadLagWindowFunction() {
+    // Given:
+    DataSchema inputSchema = new DataSchema(new String[]{"group", "arg"}, new 
ColumnDataType[]{INT, STRING});
+    when(_input.nextBlock()).thenReturn(
+            OperatorTestUtil.block(inputSchema, new Object[]{3, "and"}, new 
Object[]{2, "bar"}, new Object[]{2, "foo"},
+                new Object[]{1, "foo"})).thenReturn(
+            OperatorTestUtil.block(inputSchema, new Object[]{1, "foo"}, new 
Object[]{2, "foo"}, new Object[]{1, "numb"},
+                new Object[]{2, "the"}, new Object[]{3, "true"}))
+        
.thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0));
+    DataSchema resultSchema = new DataSchema(new String[]{"group", "arg", 
"lead", "lag"},
+        new ColumnDataType[]{INT, STRING, INT, INT});
+    List<Integer> keys = List.of(0);
+    List<RelFieldCollation> collations =
+        List.of(new RelFieldCollation(1, 
RelFieldCollation.Direction.ASCENDING, RelFieldCollation.NullDirection.LAST));
+    List<RexExpression.FunctionCall> aggCalls =
+        List.of(new RexExpression.FunctionCall(ColumnDataType.INT, 
SqlKind.LEAD.name(),
+                List.of(new RexExpression.InputRef(0), new 
RexExpression.Literal(ColumnDataType.INT, 1))),
+            new RexExpression.FunctionCall(ColumnDataType.INT, 
SqlKind.LAG.name(),
+                List.of(new RexExpression.InputRef(0), new 
RexExpression.Literal(ColumnDataType.INT, 1))));
+    WindowAggregateOperator operator =
+        getOperator(inputSchema, resultSchema, keys, collations, aggCalls, 
WindowNode.WindowFrameType.RANGE,
+            Integer.MIN_VALUE, 0);
+
+    // When:
+    List<Object[]> resultRows = operator.nextBlock().getContainer();
+    // Then:
+    verifyResultRows(resultRows, keys, Map.of(
+        1, List.of(
+            new Object[]{1, "foo", 1, null},
+            new Object[]{1, "foo", 1, 1},
+            new Object[]{1, "numb", null, 1}),
+        2, List.of(
+            new Object[]{2, "bar", 2, null},
+            new Object[]{2, "foo", 2, 2},
+            new Object[]{2, "foo", 2, 2},
+            new Object[]{2, "the", null, 2}),
+        3, List.of(
+            new Object[]{3, "and", 3, null},
+            new Object[]{3, "true", null, 3})
+    ));
+    assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second 
block is EOS (done processing)");
+  }
+
+  @Test
+  public void testLeadLagWindowFunction2() {
+    // Given:
+    DataSchema inputSchema = new DataSchema(new String[]{"group", "arg"}, new 
ColumnDataType[]{INT, STRING});
+    when(_input.nextBlock()).thenReturn(
+            OperatorTestUtil.block(inputSchema, new Object[]{3, "and"}, new 
Object[]{2, "bar"}, new Object[]{2, "foo"},
+                new Object[]{1, "foo"})).thenReturn(
+            OperatorTestUtil.block(inputSchema, new Object[]{1, "foo"}, new 
Object[]{2, "foo"}, new Object[]{1, "numb"},
+                new Object[]{2, "the"}, new Object[]{3, "true"}))
+        
.thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0));
+    DataSchema resultSchema = new DataSchema(new String[]{"group", "arg", 
"lead", "lag"},
+        new ColumnDataType[]{INT, STRING, INT, INT});
+    List<Integer> keys = List.of(0);
+    List<RelFieldCollation> collations =
+        List.of(new RelFieldCollation(1, 
RelFieldCollation.Direction.ASCENDING, RelFieldCollation.NullDirection.LAST));
+    List<RexExpression.FunctionCall> aggCalls =
+        List.of(new RexExpression.FunctionCall(ColumnDataType.INT, 
SqlKind.LEAD.name(),
+                List.of(new RexExpression.InputRef(0), new 
RexExpression.Literal(ColumnDataType.INT, 2),
+                    new RexExpression.Literal(ColumnDataType.INT, 100))),
+            new RexExpression.FunctionCall(ColumnDataType.INT, 
SqlKind.LAG.name(),
+                List.of(new RexExpression.InputRef(0), new 
RexExpression.Literal(ColumnDataType.INT, 1),
+                    new RexExpression.Literal(ColumnDataType.INT, 200))));
+    WindowAggregateOperator operator =
+        getOperator(inputSchema, resultSchema, keys, collations, aggCalls, 
WindowNode.WindowFrameType.RANGE,
+            Integer.MIN_VALUE, 0);
+
+    // When:
+    List<Object[]> resultRows = operator.nextBlock().getContainer();
+    // Then:
+    verifyResultRows(resultRows, keys, Map.of(
+        1, List.of(
+            new Object[]{1, "foo", 1, 200},
+            new Object[]{1, "foo", 100, 1},
+            new Object[]{1, "numb", 100, 1}),
+        2, List.of(
+            new Object[]{2, "bar", 2, 200},
+            new Object[]{2, "foo", 2, 2},
+            new Object[]{2, "foo", 100, 2},
+            new Object[]{2, "the", 100, 2}),
+        3, List.of(
+            new Object[]{3, "and", 100, 200},
+            new Object[]{3, "true", 100, 3})
+    ));
+    assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second 
block is EOS (done processing)");
+  }
+
   private WindowAggregateOperator getOperator(DataSchema inputSchema, 
DataSchema resultSchema, List<Integer> keys,
       List<RelFieldCollation> collations, List<RexExpression.FunctionCall> 
aggCalls,
       WindowNode.WindowFrameType windowFrameType, int lowerBound, int 
upperBound, PlanNode.NodeHint nodeHint) {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to