This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new d81ad57  [SPARK-38655][SQL] `OffsetWindowFunctionFrameBase` cannot 
find the offset row whose input is not-null
d81ad57 is described below

commit d81ad57a6b606ba2850bc0e1dec91bc28831a0ae
Author: Jiaan Geng <belie...@163.com>
AuthorDate: Mon Mar 28 14:17:39 2022 +0800

    [SPARK-38655][SQL] `OffsetWindowFunctionFrameBase` cannot find the offset 
row whose input is not-null
    
    ### What changes were proposed in this pull request?
    ```
    select x, nth_value(x, 5) IGNORE NULLS over (order by x rows between 
unbounded preceding and current row)
    from (select explode(sequence(1, 3)) x)
    ```
    The sql output:
    ```
    null
    null
    3
    ```
    But it should return
    ```
    null
    null
    null
    ```
    
    ### Why are the changes needed?
    Fix the bug where UnboundedPrecedingOffsetWindowFunctionFrame does not work correctly.
    
    ### Does this PR introduce _any_ user-facing change?
    'Yes'.
    The output will be correct after this bug is fixed.
    
    ### How was this patch tested?
    New tests.
    
    Closes #35971 from beliefer/SPARK-38655.
    
    Authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/execution/window/WindowFunctionFrame.scala | 26 +++++++++----
 .../spark/sql/DataFrameWindowFunctionsSuite.scala  | 44 ++++++++++++++++++++++
 2 files changed, 63 insertions(+), 7 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
index 7d08595..2b7f702 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
@@ -157,6 +157,9 @@ abstract class OffsetWindowFunctionFrameBase(
 
   /** find the offset row whose input is not null */
   protected def findNextRowWithNonNullInput(): Unit = {
+    // In order to find the offset row whose input is not-null,
+    // offset < = input.length must be guaranteed.
+    assert(offset <= input.length)
     while (skippedNonNullCount < offset && inputIndex < input.length) {
       val r = WindowFunctionFrame.getNextOrNull(inputIterator)
       if (!nullCheck(r)) {
@@ -165,6 +168,11 @@ abstract class OffsetWindowFunctionFrameBase(
       }
       inputIndex += 1
     }
+    if (skippedNonNullCount < offset && inputIndex == input.length) {
+      // The size of non-null input is less than offset, cannot find the 
offset row whose input
+      // is not null. Therefore, reset `nextSelectedRow` with empty row.
+      nextSelectedRow = EmptyRow
+    }
   }
 
   override def currentLowerBound(): Int = throw new 
UnsupportedOperationException()
@@ -362,14 +370,18 @@ class UnboundedPrecedingOffsetWindowFunctionFrame(
   assert(offset > 0)
 
   override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
-    resetStates(rows)
-    if (ignoreNulls) {
-      findNextRowWithNonNullInput()
+    if (offset > rows.length) {
+      fillDefaultValue(EmptyRow)
     } else {
-      // drain the first few rows if offset is larger than one
-      while (inputIndex < offset) {
-        nextSelectedRow = WindowFunctionFrame.getNextOrNull(inputIterator)
-        inputIndex += 1
+      resetStates(rows)
+      if (ignoreNulls) {
+        findNextRowWithNonNullInput()
+      } else {
+        // drain the first few rows if offset is larger than one
+        while (inputIndex < offset) {
+          nextSelectedRow = WindowFunctionFrame.getNextOrNull(inputIterator)
+          inputIndex += 1
+        }
       }
     }
   }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
index 666bf73..70b0150 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
@@ -702,6 +702,50 @@ class DataFrameWindowFunctionsSuite extends QueryTest
         Row("a", 4, "x", "x", "y", "x", "x", "y"),
         Row("b", 1, null, null, null, null, null, null),
         Row("b", 2, null, null, null, null, null, null)))
+
+    val df2 = Seq(
+      ("a", 1, "x"),
+      ("a", 2, "y"),
+      ("a", 3, "z")).
+      toDF("key", "order", "value")
+    checkAnswer(
+      df2.select(
+        $"key",
+        $"order",
+        nth_value($"value", 2).over(window1),
+        nth_value($"value", 2, ignoreNulls = true).over(window1),
+        nth_value($"value", 2).over(window2),
+        nth_value($"value", 2, ignoreNulls = true).over(window2),
+        nth_value($"value", 3).over(window1),
+        nth_value($"value", 3, ignoreNulls = true).over(window1),
+        nth_value($"value", 3).over(window2),
+        nth_value($"value", 3, ignoreNulls = true).over(window2),
+        nth_value($"value", 4).over(window1),
+        nth_value($"value", 4, ignoreNulls = true).over(window1),
+        nth_value($"value", 4).over(window2),
+        nth_value($"value", 4, ignoreNulls = true).over(window2)),
+      Seq(
+        Row("a", 1, "y", "y", null, null, "z", "z", null, null, null, null, 
null, null),
+        Row("a", 2, "y", "y", "y", "y", "z", "z", null, null, null, null, 
null, null),
+        Row("a", 3, "y", "y", "y", "y", "z", "z", "z", "z", null, null, null, 
null)))
+
+    val df3 = Seq(
+      ("a", 1, "x"),
+      ("a", 2, nullStr),
+      ("a", 3, "z")).
+      toDF("key", "order", "value")
+    checkAnswer(
+      df3.select(
+        $"key",
+        $"order",
+        nth_value($"value", 3).over(window1),
+        nth_value($"value", 3, ignoreNulls = true).over(window1),
+        nth_value($"value", 3).over(window2),
+        nth_value($"value", 3, ignoreNulls = true).over(window2)),
+      Seq(
+        Row("a", 1, "z", null, null, null),
+        Row("a", 2, "z", null, null, null),
+        Row("a", 3, "z", null, "z", null)))
   }
 
   test("nth_value on descending ordered window") {

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to