This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 9a7b6e5c31bf [SPARK-49474][SS] Classify Error class for
FlatMapGroupsWithState user function error
9a7b6e5c31bf is described below
commit 9a7b6e5c31bfcca4283ed6bc22df10b743e9a470
Author: Livia Zhu <[email protected]>
AuthorDate: Thu Sep 5 16:24:01 2024 +0900
[SPARK-49474][SS] Classify Error class for FlatMapGroupsWithState user
function error
### What changes were proposed in this pull request?
Add new error classification for errors occurring in the user function that
is used in FlatMapGroupsWithState.
### Why are the changes needed?
The user-provided function can throw any type of error. This change uses the
new error framework to provide better error messages and classification.
### Does this PR introduce _any_ user-facing change?
Yes, better error message with error class for FlatMapGroupsWithState user
function failures.
### How was this patch tested?
Updated existing tests and added new unit tests in
FlatMapGroupsWithStateSuite.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #47940 from
liviazhu-db/liviazhu-db/classify-flatmapgroupswithstate-error.
Authored-by: Livia Zhu <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
---
.../src/main/resources/error/error-conditions.json | 6 ++
.../streaming/FlatMapGroupsWithStateExec.scala | 48 +++++++++++++--
.../streaming/FlatMapGroupsWithStateSuite.scala | 70 +++++++++++++++++++++-
3 files changed, 117 insertions(+), 7 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-conditions.json
b/common/utils/src/main/resources/error/error-conditions.json
index e2725a98a63b..96105c967225 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -1457,6 +1457,12 @@
],
"sqlState" : "42704"
},
+ "FLATMAPGROUPSWITHSTATE_USER_FUNCTION_ERROR" : {
+ "message" : [
+ "An error occurred in the user provided function in
flatMapGroupsWithState. Reason: <reason>"
+ ],
+ "sqlState" : "39000"
+ },
"FORBIDDEN_OPERATION" : {
"message" : [
"The operation <statement> is not allowed on the <objectType>:
<objectName>."
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index d56dfebd61ba..766caaab2285 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -18,8 +18,11 @@ package org.apache.spark.sql.execution.streaming
import java.util.concurrent.TimeUnit.NANOSECONDS
+import scala.util.control.NonFatal
+
import org.apache.hadoop.conf.Configuration
+import org.apache.spark.{SparkException, SparkThrowable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -447,10 +450,33 @@ case class FlatMapGroupsWithStateExec(
hasTimedOut,
watermarkPresent)
- // Call function, get the returned objects and convert them to rows
- val mappedIterator = func(keyObj, valueObjIter, groupState).map { obj =>
- numOutputRows += 1
- getOutputRow(obj)
+ def withUserFuncExceptionHandling[T](func: => T): T = {
+ try {
+ func
+ } catch {
+ case NonFatal(e) if !e.isInstanceOf[SparkThrowable] =>
+ throw FlatMapGroupsWithStateUserFuncException(e)
+ case f: Throwable =>
+ throw f
+ }
+ }
+
+ val mappedIterator = withUserFuncExceptionHandling {
+ func(keyObj, valueObjIter, groupState).map { obj =>
+ numOutputRows += 1
+ getOutputRow(obj)
+ }
+ }
+
+ // Wrap user-provided fns with error handling
+ val wrappedMappedIterator = new Iterator[InternalRow] {
+ override def hasNext: Boolean = {
+ withUserFuncExceptionHandling(mappedIterator.hasNext)
+ }
+
+ override def next(): InternalRow = {
+ withUserFuncExceptionHandling(mappedIterator.next())
+ }
}
// When the iterator is consumed, then write changes to state
@@ -472,7 +498,9 @@ case class FlatMapGroupsWithStateExec(
}
// Return an iterator of rows such that fully consumed, the updated
state value will be saved
- CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator,
onIteratorCompletion)
+ CompletionIterator[InternalRow, Iterator[InternalRow]](
+ wrappedMappedIterator, onIteratorCompletion
+ )
}
}
}
@@ -544,3 +572,13 @@ object FlatMapGroupsWithStateExec {
}
}
}
+
+
+/**
+ * Exception that wraps an error thrown by the user-provided function in
flatMapGroupsWithState.
+ */
+private[sql] case class FlatMapGroupsWithStateUserFuncException(cause:
Throwable)
+ extends SparkException(
+ errorClass = "FLATMAPGROUPSWITHSTATE_USER_FUNCTION_ERROR",
+ messageParameters = Map("reason" ->
Option(cause.getMessage).getOrElse("")),
+ cause = cause)
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index 45a80a210fce..f3ef73c6af5f 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -23,7 +23,6 @@ import java.sql.Timestamp
import org.apache.commons.io.FileUtils
import org.scalatest.exceptions.TestFailedException
-import org.apache.spark.SparkException
import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction
import org.apache.spark.sql.{DataFrame, Encoder}
import org.apache.spark.sql.catalyst.InternalRow
@@ -635,6 +634,72 @@ class FlatMapGroupsWithStateSuite extends
StateStoreMetricsTest {
)
}
+ testWithAllStateVersions("[SPARK-49474] flatMapGroupsWithState - user NPE is
classified") {
+ // Throws NPE
+ val stateFunc = (_: String, _: Iterator[String], _:
GroupState[RunningCount]) => {
+ throw new NullPointerException()
+ // Need to return an iterator for compilation to get types
+ Iterator(1)
+ }
+
+ val inputData = MemoryStream[String]
+ val result =
+ inputData.toDS()
+ .groupByKey(x => x)
+ .flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc)
+
+ testStream(result, Update)(
+ AddData(inputData, "a"),
+ ExpectFailure[FlatMapGroupsWithStateUserFuncException]()
+ )
+ }
+
+ testWithAllStateVersions(
+ "[SPARK-49474] flatMapGroupsWithState - null user iterator error is
classified") {
+ // Returns null, will throw NPE when method is called on it
+ val stateFunc = (_: String, _: Iterator[String], _:
GroupState[RunningCount]) => {
+ null.asInstanceOf[Iterator[Int]]
+ }
+
+ val inputData = MemoryStream[String]
+ val result =
+ inputData.toDS()
+ .groupByKey(x => x)
+ .flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc)
+
+ testStream(result, Update)(
+ AddData(inputData, "a"),
+ ExpectFailure[FlatMapGroupsWithStateUserFuncException]()
+ )
+ }
+
+ testWithAllStateVersions(
+ "[SPARK-49474] flatMapGroupsWithState - NPE from user iterator is
classified") {
+ // Returns iterator that throws NPE when next is called
+ val stateFunc = (_: String, _: Iterator[String], _:
GroupState[RunningCount]) => {
+ new Iterator[Int] {
+ override def hasNext: Boolean = {
+ true
+ }
+
+ override def next(): Int = {
+ throw new NullPointerException()
+ }
+ }
+ }
+
+ val inputData = MemoryStream[String]
+ val result =
+ inputData.toDS()
+ .groupByKey(x => x)
+ .flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc)
+
+ testStream(result, Update)(
+ AddData(inputData, "a"),
+ ExpectFailure[FlatMapGroupsWithStateUserFuncException]()
+ )
+ }
+
test("mapGroupsWithState - streaming") {
// Function to maintain running count up to 2, and then remove the count
// Returns the data and the count (-1 if count reached beyond 2 and state
was just removed)
@@ -816,7 +881,8 @@ class FlatMapGroupsWithStateSuite extends
StateStoreMetricsTest {
CheckNewAnswer(("a", 2L)),
setFailInTask(true),
AddData(inputData, "a"),
- ExpectFailure[SparkException](), // task should fail but should not
increment count
+ // task should fail but should not increment count
+ ExpectFailure[FlatMapGroupsWithStateUserFuncException](),
setFailInTask(false),
StartStream(),
CheckNewAnswer(("a", 3L)) // task should not fail, and should show
correct count
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]