This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9a7b6e5c31bf [SPARK-49474][SS] Classify Error class for 
FlatMapGroupsWithState user function error
9a7b6e5c31bf is described below

commit 9a7b6e5c31bfcca4283ed6bc22df10b743e9a470
Author: Livia Zhu <[email protected]>
AuthorDate: Thu Sep 5 16:24:01 2024 +0900

    [SPARK-49474][SS] Classify Error class for FlatMapGroupsWithState user 
function error
    
    ### What changes were proposed in this pull request?
    
    Add new error classification for errors occurring in the user function that 
is used in FlatMapGroupsWithState.
    
    ### Why are the changes needed?
    
    The user provided function can throw any type of error. This change uses the 
new error framework to provide better error messages and classification.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, better error message with error class for FlatMapGroupsWithState user 
function failures.
    
    ### How was this patch tested?
    
    Updated existing tests and added new unit tests in 
FlatMapGroupsWithStateSuite.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #47940 from 
liviazhu-db/liviazhu-db/classify-flatmapgroupswithstate-error.
    
    Authored-by: Livia Zhu <[email protected]>
    Signed-off-by: Jungtaek Lim <[email protected]>
---
 .../src/main/resources/error/error-conditions.json |  6 ++
 .../streaming/FlatMapGroupsWithStateExec.scala     | 48 +++++++++++++--
 .../streaming/FlatMapGroupsWithStateSuite.scala    | 70 +++++++++++++++++++++-
 3 files changed, 117 insertions(+), 7 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json 
b/common/utils/src/main/resources/error/error-conditions.json
index e2725a98a63b..96105c967225 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -1457,6 +1457,12 @@
     ],
     "sqlState" : "42704"
   },
+  "FLATMAPGROUPSWITHSTATE_USER_FUNCTION_ERROR" : {
+    "message" : [
+      "An error occurred in the user provided function in 
flatMapGroupsWithState. Reason: <reason>"
+    ],
+    "sqlState" : "39000"
+  },
   "FORBIDDEN_OPERATION" : {
     "message" : [
       "The operation <statement> is not allowed on the <objectType>: 
<objectName>."
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index d56dfebd61ba..766caaab2285 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -18,8 +18,11 @@ package org.apache.spark.sql.execution.streaming
 
 import java.util.concurrent.TimeUnit.NANOSECONDS
 
+import scala.util.control.NonFatal
+
 import org.apache.hadoop.conf.Configuration
 
+import org.apache.spark.{SparkException, SparkThrowable}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -447,10 +450,33 @@ case class FlatMapGroupsWithStateExec(
         hasTimedOut,
         watermarkPresent)
 
-      // Call function, get the returned objects and convert them to rows
-      val mappedIterator = func(keyObj, valueObjIter, groupState).map { obj =>
-        numOutputRows += 1
-        getOutputRow(obj)
+      def withUserFuncExceptionHandling[T](func: => T): T = {
+        try {
+          func
+        } catch {
+          case NonFatal(e) if !e.isInstanceOf[SparkThrowable] =>
+            throw FlatMapGroupsWithStateUserFuncException(e)
+          case f: Throwable =>
+            throw f
+        }
+      }
+
+      val mappedIterator = withUserFuncExceptionHandling {
+        func(keyObj, valueObjIter, groupState).map { obj =>
+          numOutputRows += 1
+          getOutputRow(obj)
+        }
+      }
+
+      // Wrap user-provided fns with error handling
+      val wrappedMappedIterator = new Iterator[InternalRow] {
+        override def hasNext: Boolean = {
+          withUserFuncExceptionHandling(mappedIterator.hasNext)
+        }
+
+        override def next(): InternalRow = {
+          withUserFuncExceptionHandling(mappedIterator.next())
+        }
       }
 
       // When the iterator is consumed, then write changes to state
@@ -472,7 +498,9 @@ case class FlatMapGroupsWithStateExec(
       }
 
       // Return an iterator of rows such that fully consumed, the updated 
state value will be saved
-      CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, 
onIteratorCompletion)
+      CompletionIterator[InternalRow, Iterator[InternalRow]](
+        wrappedMappedIterator, onIteratorCompletion
+      )
     }
   }
 }
@@ -544,3 +572,13 @@ object FlatMapGroupsWithStateExec {
     }
   }
 }
+
+
+/**
+ * Exception that wraps the exception thrown in the user provided function in 
flatMapGroupsWithState.
+ */
+private[sql] case class FlatMapGroupsWithStateUserFuncException(cause: 
Throwable)
+  extends SparkException(
+    errorClass = "FLATMAPGROUPSWITHSTATE_USER_FUNCTION_ERROR",
+    messageParameters = Map("reason" -> 
Option(cause.getMessage).getOrElse("")),
+    cause = cause)
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index 45a80a210fce..f3ef73c6af5f 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -23,7 +23,6 @@ import java.sql.Timestamp
 import org.apache.commons.io.FileUtils
 import org.scalatest.exceptions.TestFailedException
 
-import org.apache.spark.SparkException
 import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction
 import org.apache.spark.sql.{DataFrame, Encoder}
 import org.apache.spark.sql.catalyst.InternalRow
@@ -635,6 +634,72 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest {
     )
   }
 
+  testWithAllStateVersions("[SPARK-49474] flatMapGroupsWithState - user NPE is 
classified") {
+    // Throws NPE
+    val stateFunc = (_: String, _: Iterator[String], _: 
GroupState[RunningCount]) => {
+      throw new NullPointerException()
+      // Need to return an iterator for compilation to get types
+      Iterator(1)
+    }
+
+    val inputData = MemoryStream[String]
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc)
+
+    testStream(result, Update)(
+      AddData(inputData, "a"),
+      ExpectFailure[FlatMapGroupsWithStateUserFuncException]()
+    )
+  }
+
+  testWithAllStateVersions(
+    "[SPARK-49474] flatMapGroupsWithState - null user iterator error is 
classified") {
+    // Returns null, will throw NPE when method is called on it
+    val stateFunc = (_: String, _: Iterator[String], _: 
GroupState[RunningCount]) => {
+      null.asInstanceOf[Iterator[Int]]
+    }
+
+    val inputData = MemoryStream[String]
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc)
+
+    testStream(result, Update)(
+      AddData(inputData, "a"),
+      ExpectFailure[FlatMapGroupsWithStateUserFuncException]()
+    )
+  }
+
+  testWithAllStateVersions(
+    "[SPARK-49474] flatMapGroupsWithState - NPE from user iterator is 
classified") {
+    // Returns iterator that throws NPE when next is called
+    val stateFunc = (_: String, _: Iterator[String], _: 
GroupState[RunningCount]) => {
+      new Iterator[Int] {
+        override def hasNext: Boolean = {
+          true
+        }
+
+        override def next(): Int = {
+          throw new NullPointerException()
+        }
+      }
+    }
+
+    val inputData = MemoryStream[String]
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc)
+
+    testStream(result, Update)(
+      AddData(inputData, "a"),
+      ExpectFailure[FlatMapGroupsWithStateUserFuncException]()
+    )
+  }
+
   test("mapGroupsWithState - streaming") {
     // Function to maintain running count up to 2, and then remove the count
     // Returns the data and the count (-1 if count reached beyond 2 and state 
was just removed)
@@ -816,7 +881,8 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest {
       CheckNewAnswer(("a", 2L)),
       setFailInTask(true),
       AddData(inputData, "a"),
-      ExpectFailure[SparkException](),   // task should fail but should not 
increment count
+      // task should fail but should not increment count
+      ExpectFailure[FlatMapGroupsWithStateUserFuncException](),
       setFailInTask(false),
       StartStream(),
       CheckNewAnswer(("a", 3L))     // task should not fail, and should show 
correct count


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to