This is an automated email from the ASF dual-hosted git repository.

beliefer pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6a197efeb3c [SPARK-45649][SQL] Unify the prepare framework for 
OffsetWindowFunctionFrame
6a197efeb3c is described below

commit 6a197efeb3c1cca156cd615e990e35e82ce22ee3
Author: Jiaan Geng <belie...@163.com>
AuthorDate: Mon Dec 11 19:48:14 2023 +0800

    [SPARK-45649][SQL] Unify the prepare framework for OffsetWindowFunctionFrame
    
    ### What changes were proposed in this pull request?
    Currently, the `prepare` implementations of all the
`OffsetWindowFunctionFrame` subclasses share the same code logic, shown below.
    ```
      override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
        if (offset > rows.length) {
          fillDefaultValue(EmptyRow)
        } else {
          resetStates(rows)
          if (ignoreNulls) {
            ...
          } else {
            ...
          }
        }
      }
    ```
    This PR unifies the `prepare` framework for `OffsetWindowFunctionFrame`.
    
    **Why did https://github.com/apache/spark/pull/43507 introduce the NPE
bug?**
    For example, consider a window group with 4 elements and an offset of 5.
    First, `resetStates` is not called because the offset is greater than 4.
    Afterwards, `write` iterates over the elements of the window group by visiting
`input`, but `input` is still null.
    
    This PR also adds two test cases where the absolute value of the offset is
greater than the window group size.
    
    ### Why are the changes needed?
    Unify the prepare framework for `OffsetWindowFunctionFrame`
    
    ### Does this PR introduce _any_ user-facing change?
    'No'.
    Inner update.
    
    ### How was this patch tested?
    Existing test cases, plus the two new test cases described above.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    'No'.
    
    Closes #43958 from beliefer/SPARK-45649.
    
    Authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: Jiaan Geng <belie...@163.com>
---
 .../sql/execution/window/WindowFunctionFrame.scala | 114 ++++++++++-----------
 .../spark/sql/DataFrameWindowFramesSuite.scala     |  24 +++++
 2 files changed, 76 insertions(+), 62 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
index 6cea838311a..4aa7444c407 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
@@ -87,7 +87,8 @@ abstract class OffsetWindowFunctionFrameBase(
     expressions: Array[OffsetWindowFunction],
     inputSchema: Seq[Attribute],
     newMutableProjection: (Seq[Expression], Seq[Attribute]) => 
MutableProjection,
-    offset: Int)
+    offset: Int,
+    ignoreNulls: Boolean)
   extends WindowFunctionFrame {
 
   /** Rows of the partition currently being processed. */
@@ -141,6 +142,8 @@ abstract class OffsetWindowFunctionFrameBase(
   // is not null.
   protected var skippedNonNullCount = 0
 
+  protected val absOffset = Math.abs(offset)
+
   // Reset the states by the data of the new partition.
   protected def resetStates(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
     input = rows
@@ -176,6 +179,33 @@ abstract class OffsetWindowFunctionFrameBase(
     }
   }
 
+  override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
+    resetStates(rows)
+    if (absOffset > rows.length) {
+      fillDefaultValue(EmptyRow)
+    } else {
+      if (ignoreNulls) {
+        prepareForIgnoreNulls()
+      } else {
+        prepareForRespectNulls()
+      }
+    }
+  }
+
+  protected def prepareForIgnoreNulls(): Unit = findNextRowWithNonNullInput()
+
+  protected def prepareForRespectNulls(): Unit = {
+    // drain the first few rows if offset is larger than one
+    while (inputIndex < offset) {
+      nextSelectedRow = WindowFunctionFrame.getNextOrNull(inputIterator)
+      inputIndex += 1
+    }
+    // `inputIndex` starts as 0, but the `offset` can be negative and we may 
not enter the
+    // while loop at all. We need to make sure `inputIndex` ends up as 
`offset` to meet the
+    // assumption of the write path.
+    inputIndex = offset
+  }
+
   override def currentLowerBound(): Int = throw new 
UnsupportedOperationException()
 
   override def currentUpperBound(): Int = throw new 
UnsupportedOperationException()
@@ -197,25 +227,7 @@ class FrameLessOffsetWindowFunctionFrame(
     offset: Int,
     ignoreNulls: Boolean = false)
   extends OffsetWindowFunctionFrameBase(
-    target, ordinal, expressions, inputSchema, newMutableProjection, offset) {
-
-  override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
-    resetStates(rows)
-    if (ignoreNulls) {
-      if (Math.abs(offset) > rows.length) {
-        fillDefaultValue(EmptyRow)
-      } else {
-        findNextRowWithNonNullInput()
-      }
-    } else {
-      // drain the first few rows if offset is larger than zero
-      while (inputIndex < offset) {
-        if (inputIterator.hasNext) inputIterator.next()
-        inputIndex += 1
-      }
-      inputIndex = offset
-    }
-  }
+    target, ordinal, expressions, inputSchema, newMutableProjection, offset, 
ignoreNulls) {
 
   private val doWrite = if (ignoreNulls && offset > 0) {
     // For illustration, here is one example: the input data contains nine 
rows,
@@ -261,7 +273,6 @@ class FrameLessOffsetWindowFunctionFrame(
     // 7. current row -> z, next selected row -> y, output: y;
     // 8. current row -> v, next selected row -> z, output: z;
     // 9. current row -> null, next selected row -> v, output: v;
-    val absOffset = Math.abs(offset)
     (current: InternalRow) =>
       if (skippedNonNullCount == absOffset) {
         nextSelectedRow = EmptyRow
@@ -296,7 +307,11 @@ class FrameLessOffsetWindowFunctionFrame(
   }
 
   override def write(index: Int, current: InternalRow): Unit = {
-    doWrite(current)
+    if (absOffset > input.length) {
+      // Already use default values in prepare.
+    } else {
+      doWrite(current)
+    }
   }
 }
 
@@ -318,34 +333,24 @@ class UnboundedOffsetWindowFunctionFrame(
     offset: Int,
     ignoreNulls: Boolean = false)
   extends OffsetWindowFunctionFrameBase(
-    target, ordinal, expressions, inputSchema, newMutableProjection, offset) {
+    target, ordinal, expressions, inputSchema, newMutableProjection, offset, 
ignoreNulls) {
   assert(offset > 0)
 
-  override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
-    if (offset > rows.length) {
+  override def prepareForIgnoreNulls(): Unit = {
+    super.prepareForIgnoreNulls()
+    if (nextSelectedRow == EmptyRow) {
+      // Use default values since the offset row whose input value is not null 
does not exist.
       fillDefaultValue(EmptyRow)
     } else {
-      resetStates(rows)
-      if (ignoreNulls) {
-        findNextRowWithNonNullInput()
-        if (nextSelectedRow == EmptyRow) {
-          // Use default values since the offset row whose input value is not 
null does not exist.
-          fillDefaultValue(EmptyRow)
-        } else {
-          projection(nextSelectedRow)
-        }
-      } else {
-        var selectedRow: UnsafeRow = null
-        // drain the first few rows if offset is larger than one
-        while (inputIndex < offset) {
-          selectedRow = WindowFunctionFrame.getNextOrNull(inputIterator)
-          inputIndex += 1
-        }
-        projection(selectedRow)
-      }
+      projection(nextSelectedRow)
     }
   }
 
+  override def prepareForRespectNulls(): Unit = {
+    super.prepareForRespectNulls()
+    projection(nextSelectedRow)
+  }
+
   override def write(index: Int, current: InternalRow): Unit = {
     // The results are the same for each row in the partition, and have been 
evaluated in prepare.
     // Don't need to recalculate here.
@@ -371,28 +376,13 @@ class UnboundedPrecedingOffsetWindowFunctionFrame(
     offset: Int,
     ignoreNulls: Boolean = false)
   extends OffsetWindowFunctionFrameBase(
-    target, ordinal, expressions, inputSchema, newMutableProjection, offset) {
+    target, ordinal, expressions, inputSchema, newMutableProjection, offset, 
ignoreNulls) {
   assert(offset > 0)
 
-  override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
-    if (offset > rows.length) {
-      fillDefaultValue(EmptyRow)
-    } else {
-      resetStates(rows)
-      if (ignoreNulls) {
-        findNextRowWithNonNullInput()
-      } else {
-        // drain the first few rows if offset is larger than one
-        while (inputIndex < offset) {
-          nextSelectedRow = WindowFunctionFrame.getNextOrNull(inputIterator)
-          inputIndex += 1
-        }
-      }
-    }
-  }
-
   override def write(index: Int, current: InternalRow): Unit = {
-    if (index >= inputIndex - 1 && nextSelectedRow != null) {
+    if (absOffset > input.length) {
+      // Already use default values in prepare.
+    } else if (index >= inputIndex - 1 && nextSelectedRow != null) {
       projection(nextSelectedRow)
     } else {
       fillDefaultValue(EmptyRow)
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala
index bb744cfd8ab..0e3932cf1e1 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala
@@ -65,6 +65,18 @@ class DataFrameWindowFramesSuite extends QueryTest with 
SharedSparkSession {
       Row(1, "1", null) :: Row(1, null, "3") :: Row(2, "2", null) :: Row(2, 
null, "4") :: Nil)
   }
 
+  test("lead/lag with positive offset that greater than window group size") {
+    val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value")
+    val window = Window.partitionBy($"key").orderBy($"value")
+
+    checkAnswer(
+      df.select(
+        $"key",
+        lead("value", 3).over(window),
+        lag("value", 3).over(window)),
+      Row(1, null, null) :: Row(1, null, null) :: Row(2, null, null) :: Row(2, 
null, null) :: Nil)
+  }
+
   test("lead/lag with negative offset") {
     val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value")
     val window = Window.partitionBy($"key").orderBy($"value")
@@ -77,6 +89,18 @@ class DataFrameWindowFramesSuite extends QueryTest with 
SharedSparkSession {
       Row(1, null, "3") :: Row(1, "1", null) :: Row(2, null, "4") :: Row(2, 
"2", null) :: Nil)
   }
 
+  test("lead/lag with negative offset that absolute value greater than window 
group size") {
+    val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value")
+    val window = Window.partitionBy($"key").orderBy($"value")
+
+    checkAnswer(
+      df.select(
+        $"key",
+        lead("value", -3).over(window),
+        lag("value", -3).over(window)),
+      Row(1, null, null) :: Row(1, null, null) :: Row(2, null, null) :: Row(2, 
null, null) :: Nil)
+  }
+
   test("reverse lead/lag with negative offset") {
     val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value")
     val window = Window.partitionBy($"key").orderBy($"value".desc)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to