This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new 5caaf329c71f [SPARK-49836][SQL][SS] Fix possibly broken query when window is provided to window/session_window fn
5caaf329c71f is described below
commit 5caaf329c71f15c6c57d7998054b72447ae70308
Author: Jungtaek Lim <[email protected]>
AuthorDate: Sat Oct 5 07:38:42 2024 +0900
[SPARK-49836][SQL][SS] Fix possibly broken query when window is provided to window/session_window fn
### What changes were proposed in this pull request?

This PR fixes a correctness issue where operators are lost during analysis; it happens when a window is provided to the window()/session_window() function.
The rules `TimeWindowing` and `SessionWindowing` are responsible for resolving the time window functions. When the window function takes a `window` as its parameter (time column) - in other words, when building a time window from a time window - the rule wraps the window with the WindowTime function so that the rule ResolveWindowTime will resolve it further. (TimeWindowing/SessionWindowing will then run again against the result of ResolveWindowTime.)
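To make the handoff concrete, here is a minimal sketch (not the actual rule; `wrapNestedWindow` is a made-up helper, and the imports reflect my assumption of where these expressions live in the catalyst packages):

```scala
import org.apache.spark.sql.catalyst.expressions.{TimeWindow, WindowTime}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types.StructType

// When the time column is itself a window struct, wrap it in WindowTime and
// let ResolveWindowTime finish the resolution on a later pass. This mirrors
// the first branch of the rule shown in the diff below.
def wrapNestedWindow(p: LogicalPlan): LogicalPlan = p.transformExpressions {
  case t: TimeWindow if StructType.acceptsType(t.timeColumn.dataType) =>
    t.copy(timeColumn = WindowTime(t.timeColumn))
}
```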
The issue is that the rule uses "return" for the above, intended as an "early return" because the other branch is much longer than this one. This unfortunately does not work as intended: the intent is only to exit the current local scope (mostly the end of the enclosing curly brace), but the return breaks out of the execution loop on the "outer" side. (I haven't debugged it further, but it is clear that it does not work as intended.)
Quoting from the Scala doc:
> Nonlocal returns are implemented by throwing and catching scala.runtime.NonLocalReturnException-s.
It is not entirely clear where NonLocalReturnException is caught in the call stack; it may exit execution for a much broader scope (context) than expected. Nonlocal returns were finally deprecated in Scala 3.2 and will likely be removed in the future.
https://dotty.epfl.ch/docs/reference/dropped-features/nonlocal-returns.html
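For illustration, a minimal standalone Scala sketch (not Spark code; all names are made up) of how such a `return` escapes more than the enclosing braces:

```scala
object NonLocalReturnDemo {
  // The `return` below belongs to findWindowed, not to the closure passed to
  // foreach: Scala 2 compiles it into a thrown
  // scala.runtime.NonLocalReturnControl that unwinds through foreach, so the
  // elements after the first match are never visited.
  def findWindowed(plans: Seq[String]): Option[String] = {
    plans.foreach { plan =>
      if (plan.contains("window")) {
        return Some(plan) // exits findWindowed entirely, not just the closure
      }
    }
    None // reached only when nothing matched
  }

  def main(args: Array[String]): Unit = {
    println(findWindowed(Seq("scan", "window-agg", "sort"))) // Some(window-agg)
  }
}
```

If any frame between the closure and the method boundary catches a broad Throwable, that control-flow exception can be intercepted and the traversal silently cut short - my reading of the symptom here, not a verified trace of the analyzer.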
Interestingly, it does not break every query with chained time window aggregations. Spark already has several tests using the DataFrame API, and they have not failed. The reproducer in the community report uses a SQL statement, where each aggregation is treated as a subquery.
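For reference, a hedged sketch of the chained pattern via the DataFrame API (illustrative names and data, assuming a local SparkSession; this is the style the existing suites exercise, while the equivalent SQL hits the bug):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{count, sum, window}

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val events = Seq(("2024-09-30 00:00:00", "user1"), ("2024-09-30 00:00:30", "user1"))
  .toDF("eventTime", "userId")
  .withColumn("eventTime", $"eventTime".cast("timestamp"))

// First-level tumbling window per user.
val perMinute = events
  .groupBy(window($"eventTime", "1 minute"), $"userId")
  .agg(count("*").as("numClicks"))

// Second-level window built FROM the first window column - the
// "window provided to window fn" case this fix targets.
val perTenMinutes = perMinute
  .groupBy(window($"window", "10 minutes"), $"userId")
  .agg(sum($"numClicks").as("numClicks"))
```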
This PR fixes the rule to NOT use early return and instead use a large if-else.

### Why are the changes needed?

Described above.
### Does this PR introduce any user-facing change?

Yes, this fixes the possible query breakage. The impacted workloads may not be very large, as chained time window aggregation is an advanced usage, and the bug does not break every query that uses it.
### How was this patch tested?

New UTs.
### Was this patch authored or co-authored using generative AI tooling?

No.
Closes #48309 from HeartSaVioR/SPARK-49836.
Lead-authored-by: Jungtaek Lim <[email protected]>
Co-authored-by: Andrzej Zera <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
(cherry picked from commit d8c04cf2fb7599c993948df10f4746b70f8c52b9)
Signed-off-by: Jungtaek Lim <[email protected]>
---
.../sql/catalyst/analysis/ResolveTimeWindows.scala | 248 ++++++++++-----------
.../spark/sql/DataFrameSessionWindowingSuite.scala | 51 +++++
.../spark/sql/DataFrameTimeWindowingSuite.scala | 53 +++++
3 files changed, 228 insertions(+), 124 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala
index 5ce6a531cf09..73b313434857 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala
@@ -87,84 +87,84 @@ object TimeWindowing extends Rule[LogicalPlan] {
         val window = windowExpressions.head
         if (StructType.acceptsType(window.timeColumn.dataType)) {
-          return p.transformExpressions {
+          p.transformExpressions {
             case t: TimeWindow => t.copy(timeColumn = WindowTime(window.timeColumn))
           }
-        }
-
-        val metadata = window.timeColumn match {
-          case a: Attribute => a.metadata
-          case _ => Metadata.empty
-        }
-
-        val newMetadata = new MetadataBuilder()
-          .withMetadata(metadata)
-          .putBoolean(TimeWindow.marker, true)
-          .build()
+        } else {
+          val metadata = window.timeColumn match {
+            case a: Attribute => a.metadata
+            case _ => Metadata.empty
+          }
-        def getWindow(i: Int, dataType: DataType): Expression = {
-          val timestamp = PreciseTimestampConversion(window.timeColumn, dataType, LongType)
-          val remainder = (timestamp - window.startTime) % window.slideDuration
-          val lastStart = timestamp - CaseWhen(Seq((LessThan(remainder, 0),
-            remainder + window.slideDuration)), Some(remainder))
-          val windowStart = lastStart - i * window.slideDuration
-          val windowEnd = windowStart + window.windowDuration
+          val newMetadata = new MetadataBuilder()
+            .withMetadata(metadata)
+            .putBoolean(TimeWindow.marker, true)
+            .build()
-          // We make sure value fields are nullable since the dataType of TimeWindow defines them
-          // as nullable.
-          CreateNamedStruct(
-            Literal(WINDOW_START) ::
-              PreciseTimestampConversion(windowStart, LongType, dataType).castNullable() ::
-              Literal(WINDOW_END) ::
-              PreciseTimestampConversion(windowEnd, LongType, dataType).castNullable() ::
-              Nil)
-        }
+          def getWindow(i: Int, dataType: DataType): Expression = {
+            val timestamp = PreciseTimestampConversion(window.timeColumn, dataType, LongType)
+            val remainder = (timestamp - window.startTime) % window.slideDuration
+            val lastStart = timestamp - CaseWhen(Seq((LessThan(remainder, 0),
+              remainder + window.slideDuration)), Some(remainder))
+            val windowStart = lastStart - i * window.slideDuration
+            val windowEnd = windowStart + window.windowDuration
+
+            // We make sure value fields are nullable since the dataType of TimeWindow defines them
+            // as nullable.
+            CreateNamedStruct(
+              Literal(WINDOW_START) ::
+                PreciseTimestampConversion(windowStart, LongType, dataType).castNullable() ::
+                Literal(WINDOW_END) ::
+                PreciseTimestampConversion(windowEnd, LongType, dataType).castNullable() ::
+                Nil)
+          }
-        val windowAttr = AttributeReference(
-          WINDOW_COL_NAME, window.dataType, metadata = newMetadata)()
+          val windowAttr = AttributeReference(
+            WINDOW_COL_NAME, window.dataType, metadata = newMetadata)()
-        if (window.windowDuration == window.slideDuration) {
-          val windowStruct = Alias(getWindow(0, window.timeColumn.dataType), WINDOW_COL_NAME)(
-            exprId = windowAttr.exprId, explicitMetadata = Some(newMetadata))
+          if (window.windowDuration == window.slideDuration) {
+            val windowStruct = Alias(getWindow(0, window.timeColumn.dataType), WINDOW_COL_NAME)(
+              exprId = windowAttr.exprId, explicitMetadata = Some(newMetadata))
-          val replacedPlan = p transformExpressions {
-            case t: TimeWindow => windowAttr
-          }
+            val replacedPlan = p transformExpressions {
+              case t: TimeWindow => windowAttr
+            }
-          // For backwards compatibility we add a filter to filter out nulls
-          val filterExpr = IsNotNull(window.timeColumn)
+            // For backwards compatibility we add a filter to filter out nulls
+            val filterExpr = IsNotNull(window.timeColumn)
-          replacedPlan.withNewChildren(
-            Project(windowStruct +: child.output,
-              Filter(filterExpr, child)) :: Nil)
-        } else {
-          val overlappingWindows =
-            math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt
-          val windows =
-            Seq.tabulate(overlappingWindows)(i =>
-              getWindow(i, window.timeColumn.dataType))
-
-          val projections = windows.map(_ +: child.output)
-
-          // When the condition windowDuration % slideDuration = 0 is fulfilled,
-          // the estimation of the number of windows becomes exact one,
-          // which means all produced windows are valid.
-          val filterExpr =
-            if (window.windowDuration % window.slideDuration == 0) {
-              IsNotNull(window.timeColumn)
+            replacedPlan.withNewChildren(
+              Project(windowStruct +: child.output,
+                Filter(filterExpr, child)) :: Nil)
           } else {
-              window.timeColumn >= windowAttr.getField(WINDOW_START) &&
-              window.timeColumn < windowAttr.getField(WINDOW_END)
+            val overlappingWindows =
+              math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt
+            val windows =
+              Seq.tabulate(overlappingWindows)(i =>
+                getWindow(i, window.timeColumn.dataType))
+
+            val projections = windows.map(_ +: child.output)
+
+            // When the condition windowDuration % slideDuration = 0 is fulfilled,
+            // the estimation of the number of windows becomes exact one,
+            // which means all produced windows are valid.
+            val filterExpr =
+              if (window.windowDuration % window.slideDuration == 0) {
+                IsNotNull(window.timeColumn)
+              } else {
+                window.timeColumn >= windowAttr.getField(WINDOW_START) &&
+                  window.timeColumn < windowAttr.getField(WINDOW_END)
+              }
+
+            val substitutedPlan = Filter(filterExpr,
+              Expand(projections, windowAttr +: child.output, child))
+
+            val renamedPlan = p transformExpressions {
+              case t: TimeWindow => windowAttr
+            }
+
+            renamedPlan.withNewChildren(substitutedPlan :: Nil)
           }
-
-          val substitutedPlan = Filter(filterExpr,
-            Expand(projections, windowAttr +: child.output, child))
-
-          val renamedPlan = p transformExpressions {
-            case t: TimeWindow => windowAttr
-          }
-
-          renamedPlan.withNewChildren(substitutedPlan :: Nil)
         }
       } else if (numWindowExpr > 1) {
         throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p)
@@ -209,71 +209,71 @@ object SessionWindowing extends Rule[LogicalPlan] {
         val session = sessionExpressions.head
         if (StructType.acceptsType(session.timeColumn.dataType)) {
-          return p transformExpressions {
+          p transformExpressions {
             case t: SessionWindow => t.copy(timeColumn = WindowTime(session.timeColumn))
           }
-        }
+        } else {
+          val metadata = session.timeColumn match {
+            case a: Attribute => a.metadata
+            case _ => Metadata.empty
+          }
-        val metadata = session.timeColumn match {
-          case a: Attribute => a.metadata
-          case _ => Metadata.empty
-        }
+          val newMetadata = new MetadataBuilder()
+            .withMetadata(metadata)
+            .putBoolean(SessionWindow.marker, true)
+            .build()
-        val newMetadata = new MetadataBuilder()
-          .withMetadata(metadata)
-          .putBoolean(SessionWindow.marker, true)
-          .build()
-
-        val sessionAttr = AttributeReference(
-          SESSION_COL_NAME, session.dataType, metadata = newMetadata)()
-
-        val sessionStart =
-          PreciseTimestampConversion(session.timeColumn, session.timeColumn.dataType, LongType)
-        val gapDuration = session.gapDuration match {
-          case expr if Cast.canCast(expr.dataType, CalendarIntervalType) =>
-            Cast(expr, CalendarIntervalType)
-          case other =>
-            throw QueryCompilationErrors.sessionWindowGapDurationDataTypeError(other.dataType)
-        }
-        val sessionEnd = PreciseTimestampConversion(session.timeColumn + gapDuration,
-          session.timeColumn.dataType, LongType)
-
-        // We make sure value fields are nullable since the dataType of SessionWindow defines them
-        // as nullable.
-        val literalSessionStruct = CreateNamedStruct(
-          Literal(SESSION_START) ::
-            PreciseTimestampConversion(sessionStart, LongType, session.timeColumn.dataType)
-              .castNullable() ::
-            Literal(SESSION_END) ::
-            PreciseTimestampConversion(sessionEnd, LongType, session.timeColumn.dataType)
-              .castNullable() ::
-            Nil)
-
-        val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)(
-          exprId = sessionAttr.exprId, explicitMetadata = Some(newMetadata))
+          val sessionAttr = AttributeReference(
+            SESSION_COL_NAME, session.dataType, metadata = newMetadata)()
-        val replacedPlan = p transformExpressions {
-          case s: SessionWindow => sessionAttr
-        }
+          val sessionStart =
+            PreciseTimestampConversion(session.timeColumn, session.timeColumn.dataType, LongType)
+          val gapDuration = session.gapDuration match {
+            case expr if Cast.canCast(expr.dataType, CalendarIntervalType) =>
+              Cast(expr, CalendarIntervalType)
+            case other =>
+              throw QueryCompilationErrors.sessionWindowGapDurationDataTypeError(other.dataType)
+          }
+          val sessionEnd = PreciseTimestampConversion(session.timeColumn + gapDuration,
+            session.timeColumn.dataType, LongType)
-        val filterByTimeRange = session.gapDuration match {
-          case Literal(interval: CalendarInterval, CalendarIntervalType) =>
-            interval == null || interval.months + interval.days + interval.microseconds <= 0
-          case _ => true
-        }
+          // We make sure value fields are nullable since the dataType of SessionWindow defines them
+          // as nullable.
+          val literalSessionStruct = CreateNamedStruct(
+            Literal(SESSION_START) ::
+              PreciseTimestampConversion(sessionStart, LongType, session.timeColumn.dataType)
+                .castNullable() ::
+              Literal(SESSION_END) ::
+              PreciseTimestampConversion(sessionEnd, LongType, session.timeColumn.dataType)
+                .castNullable() ::
+              Nil)
-        // As same as tumbling window, we add a filter to filter out nulls.
-        // And we also filter out events with negative or zero or invalid gap duration.
-        val filterExpr = if (filterByTimeRange) {
-          IsNotNull(session.timeColumn) &&
-            (sessionAttr.getField(SESSION_END) > sessionAttr.getField(SESSION_START))
-        } else {
-          IsNotNull(session.timeColumn)
-        }
+          val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)(
+            exprId = sessionAttr.exprId, explicitMetadata = Some(newMetadata))
-        replacedPlan.withNewChildren(
-          Filter(filterExpr,
-            Project(sessionStruct +: child.output, child)) :: Nil)
+          val replacedPlan = p transformExpressions {
+            case s: SessionWindow => sessionAttr
+          }
+
+          val filterByTimeRange = session.gapDuration match {
+            case Literal(interval: CalendarInterval, CalendarIntervalType) =>
+              interval == null || interval.months + interval.days + interval.microseconds <= 0
+            case _ => true
+          }
+
+          // As same as tumbling window, we add a filter to filter out nulls.
+          // And we also filter out events with negative or zero or invalid gap duration.
+          val filterExpr = if (filterByTimeRange) {
+            IsNotNull(session.timeColumn) &&
+              (sessionAttr.getField(SESSION_END) > sessionAttr.getField(SESSION_START))
+          } else {
+            IsNotNull(session.timeColumn)
+          }
+
+          replacedPlan.withNewChildren(
+            Filter(filterExpr,
+              Project(sessionStruct +: child.output, child)) :: Nil)
+        }
       } else if (numWindowExpr > 1) {
        throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p)
      } else {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala
index c98806822709..454549636e8f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala
@@ -547,4 +547,55 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession
}
}
}
+
+ test("SPARK-49836 using window fn with window as parameter should preserve
parent operator") {
+ withTempView("clicks") {
+ val df = Seq(
+ // small window: [00:00, 01:00), user1, 2
+ ("2024-09-30 00:00:00", "user1"), ("2024-09-30 00:00:30", "user1"),
+ // small window: [01:00, 02:00), user2, 2
+ ("2024-09-30 00:01:00", "user2"), ("2024-09-30 00:01:30", "user2"),
+ // small window: [03:00, 04:00), user1, 1
+ ("2024-09-30 00:03:30", "user1"),
+ // small window: [11:00, 12:00), user1, 3
+ ("2024-09-30 00:11:00", "user1"), ("2024-09-30 00:11:30", "user1"),
+ ("2024-09-30 00:11:45", "user1")
+ ).toDF("eventTime", "userId")
+
+ // session window: (01:00, 09:00), user1, 3 / (02:00, 07:00), user2, 2 /
+ // (12:00, 12:05), user1, 3
+
+ df.createOrReplaceTempView("clicks")
+
+ val aggregatedData = spark.sql(
+ """
+ |SELECT
+ | userId,
+ | avg(cpu_large.numClicks) AS clicksPerSession
+ |FROM
+ |(
+ | SELECT
+ | session_window(small_window, '5 minutes') AS session,
+ | userId,
+ | sum(numClicks) AS numClicks
+ | FROM
+ | (
+ | SELECT
+ | window(eventTime, '1 minute') AS small_window,
+ | userId,
+ | count(*) AS numClicks
+ | FROM clicks
+ | GROUP BY window, userId
+ | ) cpu_small
+ | GROUP BY session_window, userId
+ |) cpu_large
+ |GROUP BY userId
+ |""".stripMargin)
+
+ checkAnswer(
+ aggregatedData,
+ Seq(Row("user1", 3), Row("user2", 2))
+ )
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
index 367cdbe84472..cf656284f6aa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql
+import java.sql.Timestamp
import java.time.LocalDateTime
import org.apache.spark.sql.catalyst.encoders.RowEncoder
@@ -714,4 +715,56 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSparkSession {
)
}
}
+
+ test("SPARK-49836 using window fn with window as parameter should preserve
parent operator") {
+ withTempView("clicks") {
+ val df = Seq(
+ // small window: [00:00, 01:00), user1, 2
+ ("2024-09-30 00:00:00", "user1"), ("2024-09-30 00:00:30", "user1"),
+ // small window: [01:00, 02:00), user2, 2
+ ("2024-09-30 00:01:00", "user2"), ("2024-09-30 00:01:30", "user2"),
+ // small window: [07:00, 08:00), user1, 1
+ ("2024-09-30 00:07:00", "user1"),
+ // small window: [11:00, 12:00), user1, 3
+ ("2024-09-30 00:11:00", "user1"), ("2024-09-30 00:11:30", "user1"),
+ ("2024-09-30 00:11:45", "user1")
+ ).toDF("eventTime", "userId")
+
+    // large window: [00:00, 10:00), user1, 3, [00:00, 10:00), user2, 2, [10:00, 20:00), user1, 3
+
+ df.createOrReplaceTempView("clicks")
+
+ val aggregatedData = spark.sql(
+ """
+ |SELECT
+ | cpu_large.large_window.end AS timestamp,
+ | avg(cpu_large.numClicks) AS avgClicksPerUser
+ |FROM
+ |(
+ | SELECT
+ | window(small_window, '10 minutes') AS large_window,
+ | userId,
+ | sum(numClicks) AS numClicks
+ | FROM
+ | (
+ | SELECT
+ | window(eventTime, '1 minute') AS small_window,
+ | userId,
+ | count(*) AS numClicks
+ | FROM clicks
+ | GROUP BY window, userId
+ | ) cpu_small
+ | GROUP BY window, userId
+ |) cpu_large
+ |GROUP BY timestamp
+ |""".stripMargin)
+
+ checkAnswer(
+ aggregatedData,
+ Seq(
+ Row(Timestamp.valueOf("2024-09-30 00:10:00"), 2.5),
+ Row(Timestamp.valueOf("2024-09-30 00:20:00"), 3))
+ )
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]