rahulsmahadev commented on a change in pull request #33093:
URL: https://github.com/apache/spark/pull/33093#discussion_r662635048
##########
File path:
sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
##########
@@ -1268,12 +1267,252 @@ class FlatMapGroupsWithStateSuite extends
StateStoreMetricsTest {
assert(e.getMessage === "The output mode of function should be append or
update")
}
+ import testImplicits._
+
+ /**
+ * FlatMapGroupsWithState function that returns the key, value as passed to
it
+ * along with the updated state. The state is incremented for every value.
+ */
+ val flatMapGroupsWithStateFunc =
+ (key: String, values: Iterator[String], state: GroupState[RunningCount])
=> {
+ val valList = values.toSeq
+ if (valList.isEmpty) {
+ // When the function is called on just the initial state make sure the
other fields
+ // are set correctly
+ assert(state.exists)
+ assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 }
+ assertCannotGetWatermark { state.getCurrentWatermarkMs() }
+ assert(!state.hasTimedOut)
+ }
+ val count = state.getOption.map(_.count).getOrElse(0L) + valList.size
+ // We need to check whether the initial state is still saved even when update is not
init state or not
+ if (valList.nonEmpty || state.getOption.map(_.count).getOrElse(0L) !=
2L) {
+ // this is not reached when valList is empty and the state count is 2
+ state.update(new RunningCount(count))
+ }
+ Iterator((key, valList, count.toString))
+ }
+
+ Seq("1", "2", "6").foreach { shufflePartitions =>
+ testWithAllStateVersions(s"flatMapGroupsWithState - initial " +
+ s"state - all cases - shuffle partitions ${shufflePartitions}") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> shufflePartitions) {
+ // We will test them on different shuffle partition configuration to
make sure the
+ // grouping by key will still work. On higher number of shuffle
partitions it's possible
+ // that all keys end up on different partitions.
+ val initialState: Dataset[(String, RunningCount)] = Seq(
+ ("keyInStateAndData-1", new RunningCount(1)),
+ ("keyInStateAndData-2", new RunningCount(2)), // state.update will
not be called
+ ("keyOnlyInState-2", new RunningCount(2)),
+ ("keyOnlyInState-1", new RunningCount(1))
+ ).toDS()
+
+ val it = initialState.groupByKey(x => x._1).mapValues(_._2)
+ val inputData = MemoryStream[String]
+ val result =
+ inputData.toDS()
+ .groupByKey(x => x)
+ .flatMapGroupsWithState(
+ Update, GroupStateTimeout.NoTimeout,
it)(flatMapGroupsWithStateFunc)
+
+ testStream(result, Update)(
+ AddData(inputData, "keyOnlyInData", "keyInStateAndData-2"),
+ CheckNewAnswer(
+ ("keyOnlyInState-1", Seq[String](), "1"),
+ ("keyOnlyInState-2", Seq[String](), "2"),
+ ("keyInStateAndData-2", Seq[String]("keyInStateAndData-2"), "3"),
// inc by 1
+ ("keyInStateAndData-1", Seq[String](), "1"),
+ ("keyOnlyInData", Seq[String]("keyOnlyInData"), "1") // inc by 1
+ ),
+ assertNumStateRows(total = 5, updated = 4),
+ // Stop and Start stream to make sure initial state doesn't get
applied again.
+ StopStream,
+ StartStream(),
+ AddData(inputData, "keyInStateAndData-1"),
+ CheckNewAnswer(
+ // state incremented by 1
+ ("keyInStateAndData-1", Seq[String]("keyInStateAndData-1"), "2")
+ ),
+ assertNumStateRows(total = 5, updated = 1),
+ StopStream
+ )
+ }
+ }
+ }
+
+ testWithAllStateVersions("flatMapGroupsWithState - initial state - case
class key") {
+ val stateFunc = (key: User, values: Iterator[User], state:
GroupState[Long]) => {
+ val valList = values.toSeq
+ if (valList.isEmpty) {
+ // When the function is called on just the initial state make sure the
other fields
+ // are set correctly
+ assert(state.exists)
+ assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 }
+ assertCannotGetWatermark { state.getCurrentWatermarkMs() }
+ assert(!state.hasTimedOut)
+ }
+ val count = state.getOption.getOrElse(0L) + valList.size
+ // We need to check if not explicitly calling update will still save the
state or not
+ if (valList.nonEmpty || state.getOption.getOrElse(0L) != 2L) {
+ // this is not reached when valList is empty and the state count is 2
+ state.update(count)
+ }
+ Iterator((key, valList.map(_.name), count.toString))
+ }
+
+ val ds = Seq(
+ (User("keyInStateAndData", "1"), (1L)),
+ (User("keyOnlyInState", "1"), (1L)),
+ (User("keyOnlyInState", "2"), (2L)) // state.update will not be called
on this in the function
+ ).toDS().groupByKey(_._1).mapValues(_._2)
+
+ val inputData = MemoryStream[User]
+ val result =
+ inputData.toDS()
+ .groupByKey(x => x)
+ .flatMapGroupsWithState(Update, NoTimeout(), ds)(stateFunc)
+
+ testStream(result, Update)(
+ AddData(inputData, User("keyInStateAndData", "1"), User("keyOnlyInData",
"1")),
+ CheckNewAnswer(
+ (("keyInStateAndData", "1"), Seq[String]("keyInStateAndData"), "2"),
+ (("keyOnlyInState", "1"), Seq[String](), "1"),
+ (("keyOnlyInState", "2"), Seq[String](), "2"),
+ (("keyOnlyInData", "1"), Seq[String]("keyOnlyInData"), "1")
+ ),
+ assertNumStateRows(total = 4, updated = 3), // (keyOnlyInState, 2) does
not call update()
+ // Stop and Start stream to make sure initial state doesn't get applied
again.
+ StopStream,
+ StartStream(),
+ AddData(inputData, User("keyOnlyInData", "1")),
+ CheckNewAnswer(
+ (("keyOnlyInData", "1"), Seq[String]("keyOnlyInData"), "2")
+ ),
+ assertNumStateRows(total = 4, updated = 1),
+ StopStream
+ )
+ }
+
+ testQuietly("flatMapGroupsWithState - initial state - duplicate keys") {
+ val initialState = Seq(
+ ("a", new RunningCount(2)),
+ ("a", new RunningCount(1))
+ ).toDS().groupByKey(_._1).mapValues(_._2)
+
+ val inputData = MemoryStream[String]
+ val result =
+ inputData.toDS()
+ .groupByKey(x => x)
+ .flatMapGroupsWithState(Update, NoTimeout(),
initialState)(flatMapGroupsWithStateFunc)
+ testStream(result, Update)(
+ AddData(inputData, "a"),
+ ExpectFailure[SparkException] { e =>
+ assert(e.getCause.getMessage.contains("The initial state provided
contained " +
+ "multiple rows(state) with the same key"))
+ }
+ )
+ }
+
+ test("flatMapGroupsWithState - initial state - streaming initial state") {
+ val initialStateData = MemoryStream[(String, RunningCount)]
+ initialStateData.addData(("a", new RunningCount(1)))
+
+ val inputData = MemoryStream[String]
+
+ val result =
+ inputData.toDS()
+ .groupByKey(x => x)
+ .flatMapGroupsWithState(
+ Update, NoTimeout(),
initialStateData.toDS().groupByKey(_._1).mapValues(_._2)
+ )(flatMapGroupsWithStateFunc)
+
+ val e = intercept[AnalysisException] {
+ result.writeStream
+ .format("console")
+ .start()
+ }
+
+ val expectedError = "Non-streaming DataFrame/Dataset is not supported" +
+ " as the initial state in [flatMap|map]GroupsWithState" +
+ " operation on a streaming DataFrame/Dataset"
+ assert(e.message.contains(expectedError))
+ }
+
+ testWithAllStateVersions("mapGroupsWithState - initial state - null key") {
+ val mapGroupsWithStateFunc =
+ (key: String, values: Iterator[String], state:
GroupState[RunningCount]) => {
+ val valList = values.toList
+ val count = state.getOption.map(_.count).getOrElse(0L) + valList.size
+ state.update(new RunningCount(count))
+ (key, state.get.count.toString)
+ }
+ val initialState = Seq(
+ ("key", new RunningCount(5)),
+ (null, new RunningCount(2))
+ ).toDS().groupByKey(_._1).mapValues(_._2)
+
+ val inputData = MemoryStream[String]
+ val result =
+ inputData.toDS()
+ .groupByKey(x => x)
+ .mapGroupsWithState(NoTimeout(), initialState)(mapGroupsWithStateFunc)
+ testStream(result, Update)(
+ AddData(inputData, "key", null),
+ CheckNewAnswer(
+ ("key", "6"), // state is incremented by 1
+ (null, "3") // incremented by 1
+ ),
+ assertNumStateRows(total = 2, updated = 2),
+ StopStream
+ )
+ }
+
+ testWithAllStateVersions("flatMapGroupsWithState - initial state -
processing time timeout") {
+ // function will return -1 on timeout and returns count of the state
otherwise
+ val stateFunc =
+ (key: String, values: Iterator[(String, Long)], state:
GroupState[RunningCount]) => {
+ if (state.hasTimedOut) {
+ state.remove()
+ Iterator((key, "-1"))
+ } else {
+ val count = state.getOption.map(_.count).getOrElse(0L) + values.size
+ state.update(RunningCount(count))
+ state.setTimeoutDuration("10 seconds")
+ Iterator((key, count.toString))
+ }
+ }
+
+ val clock = new StreamManualClock
+ val inputData = MemoryStream[(String, Long)]
+ val initialState = Seq(
+ ("c", new RunningCount(2))
+ ).toDS().groupByKey(_._1).mapValues(_._2)
+ val result =
+ inputData.toDF().toDF("key", "time")
+ .selectExpr("key", "timestamp_seconds(time) as timestamp")
+ .withWatermark("timestamp", "10 second")
+ .as[(String, Long)]
+ .groupByKey(x => x._1)
+ .flatMapGroupsWithState(Update, ProcessingTimeTimeout(),
initialState)(stateFunc)
+
+ testStream(result, Update)(
+ StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
+ AddData(inputData, ("a", 1L)),
+ AdvanceManualClock(1 * 1000), // a and c are processed here for the
first time.
+ CheckNewAnswer(("a", "1"), ("c", "2")),
+ AdvanceManualClock(10 * 1000),
+ AddData(inputData, ("b", 1L)), // this will trigger c and a to time
out
+ AdvanceManualClock(1 * 1000),
+ CheckNewAnswer(("a", "-1"), ("b", "1"), ("c", "-1"))
+ )
+ }
+
def testWithTimeout(timeoutConf: GroupStateTimeout): Unit = {
test("SPARK-20714: watermark does not fail query when timeout = " +
timeoutConf) {
// Function to maintain running count up to 2, and then remove the count
// Returns the data and the count (-1 if count reached beyond 2 and
state was just removed)
val stateFunc =
- (key: String, values: Iterator[(String, Long)], state:
GroupState[RunningCount]) => {
+ (key: String, values: Iterator[(String, Long)], state:
GroupState[RunningCount]) => {
Review comment:
nit: fix this
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]