This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 5caaf329c71f [SPARK-49836][SQL][SS] Fix possibly broken query when window is provided to window/session_window fn
5caaf329c71f is described below

commit 5caaf329c71f15c6c57d7998054b72447ae70308
Author: Jungtaek Lim <[email protected]>
AuthorDate: Sat Oct 5 07:38:42 2024 +0900

    [SPARK-49836][SQL][SS] Fix possibly broken query when window is provided to window/session_window fn
    
    This PR fixes a correctness issue in which operators are lost during analysis. It happens when a window is provided to the window()/session_window() function.
    
    The rules `TimeWindowing` and `SessionWindowing` are responsible for resolving the time window functions. When the window function takes `window` as its time column parameter (in other words, when building a time window from a time window), the rule wraps the window with the WindowTime function so that the rule ResolveWindowTime will resolve it further. (TimeWindowing/SessionWindowing will then resolve it again against the result of ResolveWindowTime.)
    
    The issue is that the rule uses "return" for the above, intending an "early return" because the other branch is much longer than this one. Unfortunately this does not work as intended: the intention is just to leave the current local scope (roughly, the end of the curly brace), but it seems to break out of the execution loop on the "outer" side.
    (I haven't debugged further, but it's clear that it doesn't work as intended.)
    
    Quoting from Scala doc:
    
    > Nonlocal returns are implemented by throwing and catching scala.runtime.NonLocalReturnException-s.
    
    It's not entirely clear where NonLocalReturnException is caught in the call stack; it might exit execution for a much broader scope (context) than expected. Nonlocal returns were finally deprecated in Scala 3.2 and will likely be removed in the future.
    
    https://dotty.epfl.ch/docs/reference/dropped-features/nonlocal-returns.html
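    
    As a rough illustration of the pitfall (a minimal sketch in plain Scala, not the Spark code; `NonLocalReturnDemo`, `transformAll`, `withEarlyReturn` and `withIfElse` are hypothetical names made up for this example), a `return` inside a closure leaves the enclosing method rather than just the closure body:
    
    ```scala
    object NonLocalReturnDemo {
      // Hypothetical stand-in for a tree transform: applies `rule` to every node.
      def transformAll(nodes: Seq[String])(rule: String => String): Seq[String] =
        nodes.map(rule)

      // `return` inside the closure is a nonlocal return: it exits
      // `withEarlyReturn` itself, so the other nodes are dropped from the result.
      def withEarlyReturn(nodes: Seq[String]): Seq[String] = {
        transformAll(nodes) { n =>
          if (n.startsWith("window")) {
            return Seq(n + "_wrapped")
          }
          n + "_resolved"
        }
      }

      // The same logic as a plain if/else: every node is kept and transformed.
      def withIfElse(nodes: Seq[String]): Seq[String] = {
        transformAll(nodes) { n =>
          if (n.startsWith("window")) n + "_wrapped" else n + "_resolved"
        }
      }
    }
    ```
    
    With the input `Seq("agg", "window", "project")`, `withEarlyReturn` yields only `Seq("window_wrapped")`, while `withIfElse` yields `Seq("agg_resolved", "window_wrapped", "project_resolved")`. This only mirrors the shape of the problem; where exactly the nonlocal return lands inside the analyzer has not been traced here.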
    
    Interestingly, it does not break every query with chained time window aggregations. Spark already has several tests using the DataFrame API and they haven't failed. The reproducer in the community report uses a SQL statement, where each aggregation is treated as a subquery.
    
    This PR fixes the rules to NOT use early return and instead use one large if/else.
    
    Described above.
    
    Yes, this fixes the possible query breakage. The impacted workloads may not be numerous, as chained time window aggregation is an advanced usage and the bug does not break every query that uses it.
    
    New UTs.
    
    No.
    
    Closes #48309 from HeartSaVioR/SPARK-49836.
    
    Lead-authored-by: Jungtaek Lim <[email protected]>
    Co-authored-by: Andrzej Zera <[email protected]>
    Signed-off-by: Jungtaek Lim <[email protected]>
    (cherry picked from commit d8c04cf2fb7599c993948df10f4746b70f8c52b9)
    Signed-off-by: Jungtaek Lim <[email protected]>
---
 .../sql/catalyst/analysis/ResolveTimeWindows.scala | 248 ++++++++++-----------
 .../spark/sql/DataFrameSessionWindowingSuite.scala |  51 +++++
 .../spark/sql/DataFrameTimeWindowingSuite.scala    |  53 +++++
 3 files changed, 228 insertions(+), 124 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala
index 5ce6a531cf09..73b313434857 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala
@@ -87,84 +87,84 @@ object TimeWindowing extends Rule[LogicalPlan] {
         val window = windowExpressions.head
 
         if (StructType.acceptsType(window.timeColumn.dataType)) {
-          return p.transformExpressions {
+          p.transformExpressions {
             case t: TimeWindow => t.copy(timeColumn = WindowTime(window.timeColumn))
           }
-        }
-
-        val metadata = window.timeColumn match {
-          case a: Attribute => a.metadata
-          case _ => Metadata.empty
-        }
-
-        val newMetadata = new MetadataBuilder()
-          .withMetadata(metadata)
-          .putBoolean(TimeWindow.marker, true)
-          .build()
+        } else {
+          val metadata = window.timeColumn match {
+            case a: Attribute => a.metadata
+            case _ => Metadata.empty
+          }
 
-        def getWindow(i: Int, dataType: DataType): Expression = {
-          val timestamp = PreciseTimestampConversion(window.timeColumn, dataType, LongType)
-          val remainder = (timestamp - window.startTime) % window.slideDuration
-          val lastStart = timestamp - CaseWhen(Seq((LessThan(remainder, 0),
-            remainder + window.slideDuration)), Some(remainder))
-          val windowStart = lastStart - i * window.slideDuration
-          val windowEnd = windowStart + window.windowDuration
+          val newMetadata = new MetadataBuilder()
+            .withMetadata(metadata)
+            .putBoolean(TimeWindow.marker, true)
+            .build()
 
-          // We make sure value fields are nullable since the dataType of TimeWindow defines them
-          // as nullable.
-          CreateNamedStruct(
-            Literal(WINDOW_START) ::
-              PreciseTimestampConversion(windowStart, LongType, dataType).castNullable() ::
-              Literal(WINDOW_END) ::
-              PreciseTimestampConversion(windowEnd, LongType, dataType).castNullable() ::
-              Nil)
-        }
+          def getWindow(i: Int, dataType: DataType): Expression = {
+            val timestamp = PreciseTimestampConversion(window.timeColumn, dataType, LongType)
+            val remainder = (timestamp - window.startTime) % window.slideDuration
+            val lastStart = timestamp - CaseWhen(Seq((LessThan(remainder, 0),
+              remainder + window.slideDuration)), Some(remainder))
+            val windowStart = lastStart - i * window.slideDuration
+            val windowEnd = windowStart + window.windowDuration
+
+            // We make sure value fields are nullable since the dataType of TimeWindow defines them
+            // as nullable.
+            CreateNamedStruct(
+              Literal(WINDOW_START) ::
+                PreciseTimestampConversion(windowStart, LongType, dataType).castNullable() ::
+                Literal(WINDOW_END) ::
+                PreciseTimestampConversion(windowEnd, LongType, dataType).castNullable() ::
+                Nil)
+          }
 
-        val windowAttr = AttributeReference(
-          WINDOW_COL_NAME, window.dataType, metadata = newMetadata)()
+          val windowAttr = AttributeReference(
+            WINDOW_COL_NAME, window.dataType, metadata = newMetadata)()
 
-        if (window.windowDuration == window.slideDuration) {
-          val windowStruct = Alias(getWindow(0, window.timeColumn.dataType), WINDOW_COL_NAME)(
-            exprId = windowAttr.exprId, explicitMetadata = Some(newMetadata))
+          if (window.windowDuration == window.slideDuration) {
+            val windowStruct = Alias(getWindow(0, window.timeColumn.dataType), WINDOW_COL_NAME)(
+              exprId = windowAttr.exprId, explicitMetadata = Some(newMetadata))
 
-          val replacedPlan = p transformExpressions {
-            case t: TimeWindow => windowAttr
-          }
+            val replacedPlan = p transformExpressions {
+              case t: TimeWindow => windowAttr
+            }
 
-          // For backwards compatibility we add a filter to filter out nulls
-          val filterExpr = IsNotNull(window.timeColumn)
+            // For backwards compatibility we add a filter to filter out nulls
+            val filterExpr = IsNotNull(window.timeColumn)
 
-          replacedPlan.withNewChildren(
-            Project(windowStruct +: child.output,
-              Filter(filterExpr, child)) :: Nil)
-        } else {
-          val overlappingWindows =
-            math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt
-          val windows =
-            Seq.tabulate(overlappingWindows)(i =>
-              getWindow(i, window.timeColumn.dataType))
-
-          val projections = windows.map(_ +: child.output)
-
-          // When the condition windowDuration % slideDuration = 0 is fulfilled,
-          // the estimation of the number of windows becomes exact one,
-          // which means all produced windows are valid.
-          val filterExpr =
-          if (window.windowDuration % window.slideDuration == 0) {
-            IsNotNull(window.timeColumn)
+            replacedPlan.withNewChildren(
+              Project(windowStruct +: child.output,
+                Filter(filterExpr, child)) :: Nil)
           } else {
-            window.timeColumn >= windowAttr.getField(WINDOW_START) &&
-              window.timeColumn < windowAttr.getField(WINDOW_END)
+            val overlappingWindows =
+              math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt
+            val windows =
+              Seq.tabulate(overlappingWindows)(i =>
+                getWindow(i, window.timeColumn.dataType))
+
+            val projections = windows.map(_ +: child.output)
+
+            // When the condition windowDuration % slideDuration = 0 is fulfilled,
+            // the estimation of the number of windows becomes exact one,
+            // which means all produced windows are valid.
+            val filterExpr =
+            if (window.windowDuration % window.slideDuration == 0) {
+              IsNotNull(window.timeColumn)
+            } else {
+              window.timeColumn >= windowAttr.getField(WINDOW_START) &&
+                window.timeColumn < windowAttr.getField(WINDOW_END)
+            }
+
+            val substitutedPlan = Filter(filterExpr,
+              Expand(projections, windowAttr +: child.output, child))
+
+            val renamedPlan = p transformExpressions {
+              case t: TimeWindow => windowAttr
+            }
+
+            renamedPlan.withNewChildren(substitutedPlan :: Nil)
           }
-
-          val substitutedPlan = Filter(filterExpr,
-            Expand(projections, windowAttr +: child.output, child))
-
-          val renamedPlan = p transformExpressions {
-            case t: TimeWindow => windowAttr
-          }
-
-          renamedPlan.withNewChildren(substitutedPlan :: Nil)
         }
       } else if (numWindowExpr > 1) {
         throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p)
@@ -209,71 +209,71 @@ object SessionWindowing extends Rule[LogicalPlan] {
         val session = sessionExpressions.head
 
         if (StructType.acceptsType(session.timeColumn.dataType)) {
-          return p transformExpressions {
+          p transformExpressions {
             case t: SessionWindow => t.copy(timeColumn = WindowTime(session.timeColumn))
           }
-        }
+        } else {
+          val metadata = session.timeColumn match {
+            case a: Attribute => a.metadata
+            case _ => Metadata.empty
+          }
 
-        val metadata = session.timeColumn match {
-          case a: Attribute => a.metadata
-          case _ => Metadata.empty
-        }
+          val newMetadata = new MetadataBuilder()
+            .withMetadata(metadata)
+            .putBoolean(SessionWindow.marker, true)
+            .build()
 
-        val newMetadata = new MetadataBuilder()
-          .withMetadata(metadata)
-          .putBoolean(SessionWindow.marker, true)
-          .build()
-
-        val sessionAttr = AttributeReference(
-          SESSION_COL_NAME, session.dataType, metadata = newMetadata)()
-
-        val sessionStart =
-          PreciseTimestampConversion(session.timeColumn, session.timeColumn.dataType, LongType)
-        val gapDuration = session.gapDuration match {
-          case expr if Cast.canCast(expr.dataType, CalendarIntervalType) =>
-            Cast(expr, CalendarIntervalType)
-          case other =>
-            throw QueryCompilationErrors.sessionWindowGapDurationDataTypeError(other.dataType)
-        }
-        val sessionEnd = PreciseTimestampConversion(session.timeColumn + gapDuration,
-          session.timeColumn.dataType, LongType)
-
-        // We make sure value fields are nullable since the dataType of SessionWindow defines them
-        // as nullable.
-        val literalSessionStruct = CreateNamedStruct(
-          Literal(SESSION_START) ::
-            PreciseTimestampConversion(sessionStart, LongType, session.timeColumn.dataType)
-              .castNullable() ::
-            Literal(SESSION_END) ::
-            PreciseTimestampConversion(sessionEnd, LongType, session.timeColumn.dataType)
-              .castNullable() ::
-            Nil)
-
-        val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)(
-          exprId = sessionAttr.exprId, explicitMetadata = Some(newMetadata))
+          val sessionAttr = AttributeReference(
+            SESSION_COL_NAME, session.dataType, metadata = newMetadata)()
 
-        val replacedPlan = p transformExpressions {
-          case s: SessionWindow => sessionAttr
-        }
+          val sessionStart =
+            PreciseTimestampConversion(session.timeColumn, session.timeColumn.dataType, LongType)
+          val gapDuration = session.gapDuration match {
+            case expr if Cast.canCast(expr.dataType, CalendarIntervalType) =>
+              Cast(expr, CalendarIntervalType)
+            case other =>
+              throw QueryCompilationErrors.sessionWindowGapDurationDataTypeError(other.dataType)
+          }
+          val sessionEnd = PreciseTimestampConversion(session.timeColumn + gapDuration,
+            session.timeColumn.dataType, LongType)
 
-        val filterByTimeRange = session.gapDuration match {
-          case Literal(interval: CalendarInterval, CalendarIntervalType) =>
-            interval == null || interval.months + interval.days + interval.microseconds <= 0
-          case _ => true
-        }
+          // We make sure value fields are nullable since the dataType of SessionWindow defines them
+          // as nullable.
+          val literalSessionStruct = CreateNamedStruct(
+            Literal(SESSION_START) ::
+              PreciseTimestampConversion(sessionStart, LongType, session.timeColumn.dataType)
+                .castNullable() ::
+              Literal(SESSION_END) ::
+              PreciseTimestampConversion(sessionEnd, LongType, session.timeColumn.dataType)
+                .castNullable() ::
+              Nil)
 
-        // As same as tumbling window, we add a filter to filter out nulls.
-        // And we also filter out events with negative or zero or invalid gap duration.
-        val filterExpr = if (filterByTimeRange) {
-          IsNotNull(session.timeColumn) &&
-            (sessionAttr.getField(SESSION_END) > sessionAttr.getField(SESSION_START))
-        } else {
-          IsNotNull(session.timeColumn)
-        }
+          val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)(
+            exprId = sessionAttr.exprId, explicitMetadata = Some(newMetadata))
 
-        replacedPlan.withNewChildren(
-          Filter(filterExpr,
-            Project(sessionStruct +: child.output, child)) :: Nil)
+          val replacedPlan = p transformExpressions {
+            case s: SessionWindow => sessionAttr
+          }
+
+          val filterByTimeRange = session.gapDuration match {
+            case Literal(interval: CalendarInterval, CalendarIntervalType) =>
+              interval == null || interval.months + interval.days + interval.microseconds <= 0
+            case _ => true
+          }
+
+          // As same as tumbling window, we add a filter to filter out nulls.
+          // And we also filter out events with negative or zero or invalid gap duration.
+          val filterExpr = if (filterByTimeRange) {
+            IsNotNull(session.timeColumn) &&
+              (sessionAttr.getField(SESSION_END) > sessionAttr.getField(SESSION_START))
+          } else {
+            IsNotNull(session.timeColumn)
+          }
+
+          replacedPlan.withNewChildren(
+            Filter(filterExpr,
+              Project(sessionStruct +: child.output, child)) :: Nil)
+        }
       } else if (numWindowExpr > 1) {
         throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p)
       } else {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala
index c98806822709..454549636e8f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala
@@ -547,4 +547,55 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession
       }
     }
   }
+
+  test("SPARK-49836 using window fn with window as parameter should preserve parent operator") {
+    withTempView("clicks") {
+      val df = Seq(
+        // small window: [00:00, 01:00), user1, 2
+        ("2024-09-30 00:00:00", "user1"), ("2024-09-30 00:00:30", "user1"),
+        // small window: [01:00, 02:00), user2, 2
+        ("2024-09-30 00:01:00", "user2"), ("2024-09-30 00:01:30", "user2"),
+        // small window: [03:00, 04:00), user1, 1
+        ("2024-09-30 00:03:30", "user1"),
+        // small window: [11:00, 12:00), user1, 3
+        ("2024-09-30 00:11:00", "user1"), ("2024-09-30 00:11:30", "user1"),
+        ("2024-09-30 00:11:45", "user1")
+      ).toDF("eventTime", "userId")
+
+      // session window: (01:00, 09:00), user1, 3 / (02:00, 07:00), user2, 2 /
+      //   (12:00, 12:05), user1, 3
+
+      df.createOrReplaceTempView("clicks")
+
+      val aggregatedData = spark.sql(
+        """
+          |SELECT
+          |  userId,
+          |  avg(cpu_large.numClicks) AS clicksPerSession
+          |FROM
+          |(
+          |  SELECT
+          |    session_window(small_window, '5 minutes') AS session,
+          |    userId,
+          |    sum(numClicks) AS numClicks
+          |  FROM
+          |  (
+          |    SELECT
+          |      window(eventTime, '1 minute') AS small_window,
+          |      userId,
+          |      count(*) AS numClicks
+          |    FROM clicks
+          |    GROUP BY window, userId
+          |  ) cpu_small
+          |  GROUP BY session_window, userId
+          |) cpu_large
+          |GROUP BY userId
+          |""".stripMargin)
+
+      checkAnswer(
+        aggregatedData,
+        Seq(Row("user1", 3), Row("user2", 2))
+      )
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
index 367cdbe84472..cf656284f6aa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql
 
+import java.sql.Timestamp
 import java.time.LocalDateTime
 
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
@@ -714,4 +715,56 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSparkSession {
       )
     }
   }
+
+  test("SPARK-49836 using window fn with window as parameter should preserve parent operator") {
+    withTempView("clicks") {
+      val df = Seq(
+        // small window: [00:00, 01:00), user1, 2
+        ("2024-09-30 00:00:00", "user1"), ("2024-09-30 00:00:30", "user1"),
+        // small window: [01:00, 02:00), user2, 2
+        ("2024-09-30 00:01:00", "user2"), ("2024-09-30 00:01:30", "user2"),
+        // small window: [07:00, 08:00), user1, 1
+        ("2024-09-30 00:07:00", "user1"),
+        // small window: [11:00, 12:00), user1, 3
+        ("2024-09-30 00:11:00", "user1"), ("2024-09-30 00:11:30", "user1"),
+        ("2024-09-30 00:11:45", "user1")
+      ).toDF("eventTime", "userId")
+
+      // large window: [00:00, 10:00), user1, 3, [00:00, 10:00), user2, 2, [10:00, 20:00), user1, 3
+
+      df.createOrReplaceTempView("clicks")
+
+      val aggregatedData = spark.sql(
+        """
+          |SELECT
+          |  cpu_large.large_window.end AS timestamp,
+          |  avg(cpu_large.numClicks) AS avgClicksPerUser
+          |FROM
+          |(
+          |  SELECT
+          |    window(small_window, '10 minutes') AS large_window,
+          |    userId,
+          |    sum(numClicks) AS numClicks
+          |  FROM
+          |  (
+          |    SELECT
+          |      window(eventTime, '1 minute') AS small_window,
+          |      userId,
+          |      count(*) AS numClicks
+          |    FROM clicks
+          |    GROUP BY window, userId
+          |  ) cpu_small
+          |  GROUP BY window, userId
+          |) cpu_large
+          |GROUP BY timestamp
+          |""".stripMargin)
+
+      checkAnswer(
+        aggregatedData,
+        Seq(
+          Row(Timestamp.valueOf("2024-09-30 00:10:00"), 2.5),
+          Row(Timestamp.valueOf("2024-09-30 00:20:00"), 3))
+      )
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
