HeartSaVioR commented on code in PR #37907:
URL: https://github.com/apache/spark/pull/37907#discussion_r972579759


##########
sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala:
##########
@@ -78,416 +78,6 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest {
     }
   }
 
-  test("SPARK-35800: ensure TestGroupState creates instances the same as 
prod") {

Review Comment:
   Lines L81 to L490 are moved out to GroupStateSuite.



##########
sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala:
##########
@@ -0,0 +1,365 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.{AnalysisException, Dataset, 
KeyValueGroupedDataset}
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Update
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import 
org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper
+import org.apache.spark.sql.internal.SQLConf
+import 
org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.{assertCanGetProcessingTime,
 assertCannotGetWatermark}
+import org.apache.spark.sql.streaming.GroupStateTimeout.{EventTimeTimeout, 
NoTimeout, ProcessingTimeTimeout}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+
+class FlatMapGroupsWithStateWithInitialStateSuite extends 
StateStoreMetricsTest {
+  import testImplicits._
+
+  /**
+   * FlatMapGroupsWithState function that returns the key, value as passed to 
it
+   * along with the updated state. The state is incremented for every value.
+   */
+  val flatMapGroupsWithStateFunc =
+    (key: String, values: Iterator[String], state: GroupState[RunningCount]) 
=> {
+      val valList = values.toSeq
+      if (valList.isEmpty) {
+        // When the function is called on just the initial state make sure the 
other fields
+        // are set correctly
+        assert(state.exists)
+      }
+      assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 }
+      assertCannotGetWatermark { state.getCurrentWatermarkMs() }
+      assert(!state.hasTimedOut)
+      if (key.contains("EventTime")) {
+        state.setTimeoutTimestamp(0, "1 hour")
+      }
+      if (key.contains("ProcessingTime")) {
+        state.setTimeoutDuration("1  hour")
+      }
+      val count = state.getOption.map(_.count).getOrElse(0L) + valList.size
+      // We need to check if not explicitly calling update will still save the 
init state or not
+      if (!key.contains("NoUpdate")) {
+        // this is not reached when valList is empty and the state count is 2
+        state.update(new RunningCount(count))
+      }
+      Iterator((key, valList, count.toString))
+    }
+
+  Seq("1", "2", "6").foreach { shufflePartitions =>
+    testWithAllStateVersions(s"flatMapGroupsWithState - initial " +
+      s"state - all cases - shuffle partitions ${shufflePartitions}") {
+      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> shufflePartitions) {
+        // We will test them on different shuffle partition configuration to 
make sure the
+        // grouping by key will still work. On higher number of shuffle 
partitions its possible
+        // that all keys end up on different partitions.
+        val initialState: Dataset[(String, RunningCount)] = Seq(
+          ("keyInStateAndData-1", new RunningCount(1)),
+          ("keyInStateAndData-2", new RunningCount(2)),
+          ("keyNoUpdate", new RunningCount(2)), // state.update will not be 
called
+          ("keyOnlyInState-1", new RunningCount(1))
+        ).toDS()
+
+        val it = initialState.groupByKey(x => x._1).mapValues(_._2)
+        val inputData = MemoryStream[String]
+        val result =
+          inputData.toDS()
+            .groupByKey(x => x)
+            .flatMapGroupsWithState(
+              Update, GroupStateTimeout.NoTimeout, 
it)(flatMapGroupsWithStateFunc)
+
+        testStream(result, Update)(
+          AddData(inputData, "keyOnlyInData", "keyInStateAndData-2"),
+          CheckNewAnswer(
+            ("keyOnlyInState-1", Seq[String](), "1"),
+            ("keyNoUpdate", Seq[String](), "2"), // update will not be called
+            ("keyInStateAndData-2", Seq[String]("keyInStateAndData-2"), "3"), 
// inc by 1
+            ("keyInStateAndData-1", Seq[String](), "1"),
+            ("keyOnlyInData", Seq[String]("keyOnlyInData"), "1") // inc by 1
+          ),
+          assertNumStateRows(total = 5, updated = 4),
+          // Stop and Start stream to make sure initial state doesn't get 
applied again.
+          StopStream,
+          StartStream(),
+          AddData(inputData, "keyInStateAndData-1"),
+          CheckNewAnswer(
+            // state incremented by 1
+            ("keyInStateAndData-1", Seq[String]("keyInStateAndData-1"), "2")
+          ),
+          assertNumStateRows(total = 5, updated = 1),
+          StopStream
+        )
+      }
+    }
+  }
+
+  testWithAllStateVersions("flatMapGroupsWithState - initial state - case 
class key") {
+    val stateFunc = (key: User, values: Iterator[User], state: 
GroupState[Long]) => {
+      val valList = values.toSeq
+      if (valList.isEmpty) {
+        // When the function is called on just the initial state make sure the 
other fields
+        // are set correctly
+        assert(state.exists)
+      }
+      assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 }
+      assertCannotGetWatermark { state.getCurrentWatermarkMs() }
+      assert(!state.hasTimedOut)
+      val count = state.getOption.getOrElse(0L) + valList.size
+      // We need to check if not explicitly calling update will still save the 
state or not
+      if (!key.name.contains("NoUpdate")) {
+        // this is not reached when valList is empty and the state count is 2
+        state.update(count)
+      }
+      Iterator((key, valList.map(_.name), count.toString))
+    }
+
+    val ds = Seq(
+      (User("keyInStateAndData", "1"), (1L)),
+      (User("keyOnlyInState", "1"), (1L)),
+      (User("keyNoUpdate", "2"), (2L)) // state.update will not be called on 
this in the function
+    ).toDS().groupByKey(_._1).mapValues(_._2)
+
+    val inputData = MemoryStream[User]
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .flatMapGroupsWithState(Update, NoTimeout(), ds)(stateFunc)
+
+    testStream(result, Update)(
+      AddData(inputData, User("keyInStateAndData", "1"), User("keyOnlyInData", 
"1")),
+      CheckNewAnswer(
+        (("keyInStateAndData", "1"), Seq[String]("keyInStateAndData"), "2"),
+        (("keyOnlyInState", "1"), Seq[String](), "1"),
+        (("keyNoUpdate", "2"), Seq[String](), "2"),
+        (("keyOnlyInData", "1"), Seq[String]("keyOnlyInData"), "1")
+      ),
+      assertNumStateRows(total = 4, updated = 3), // (keyOnlyInState, 2) does 
not call update()
+      // Stop and Start stream to make sure initial state doesn't get applied 
again.
+      StopStream,
+      StartStream(),
+      AddData(inputData, User("keyOnlyInData", "1")),
+      CheckNewAnswer(
+        (("keyOnlyInData", "1"), Seq[String]("keyOnlyInData"), "2")
+      ),
+      assertNumStateRows(total = 4, updated = 1),
+      StopStream
+    )
+  }
+
+  testQuietly("flatMapGroupsWithState - initial state - duplicate keys") {
+    val initialState = Seq(
+      ("a", new RunningCount(2)),
+      ("a", new RunningCount(1))
+    ).toDS().groupByKey(_._1).mapValues(_._2)
+
+    val inputData = MemoryStream[String]
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .flatMapGroupsWithState(Update, NoTimeout(), 
initialState)(flatMapGroupsWithStateFunc)
+    testStream(result, Update)(
+      AddData(inputData, "a"),
+      ExpectFailure[SparkException] { e =>
+        assert(e.getCause.getMessage.contains("The initial state provided 
contained " +
+          "multiple rows(state) with the same key"))
+      }
+    )
+  }
+
+  Seq(NoTimeout(), EventTimeTimeout(), ProcessingTimeTimeout()).foreach { 
timeout =>
+    test(s"flatMapGroupsWithState - initial state - batch mode - timeout 
${timeout}") {
+      // We will test them on different shuffle partition configuration to 
make sure the
+      // grouping by key will still work. On higher number of shuffle 
partitions its possible
+      // that all keys end up on different partitions.
+      val initialState = Seq(
+        (s"keyInStateAndData-1-$timeout", new RunningCount(1)),
+        ("keyInStateAndData-2", new RunningCount(2)),
+        ("keyNoUpdate", new RunningCount(2)), // state.update will not be 
called
+        ("keyOnlyInState-1", new RunningCount(1))
+      ).toDS().groupByKey(x => x._1).mapValues(_._2)
+
+      val inputData = Seq(
+        ("keyOnlyInData"), ("keyInStateAndData-2")
+      )
+      val result = inputData.toDS().groupByKey(x => x)
+        .flatMapGroupsWithState(
+          Update, timeout, initialState)(flatMapGroupsWithStateFunc)
+
+      val expected = Seq(
+        ("keyOnlyInState-1", Seq[String](), "1"),
+        ("keyNoUpdate", Seq[String](), "2"), // update will not be called
+        ("keyInStateAndData-2", Seq[String]("keyInStateAndData-2"), "3"), // 
inc by 1
+        (s"keyInStateAndData-1-$timeout", Seq[String](), "1"),
+        ("keyOnlyInData", Seq[String]("keyOnlyInData"), "1") // inc by 1
+      ).toDF()
+      checkAnswer(result.toDF(), expected)
+    }
+  }
+
+  testQuietly("flatMapGroupsWithState - initial state - batch mode - duplicate 
state") {
+    val initialState = Seq(
+      ("a", new RunningCount(1)),
+      ("a", new RunningCount(2))
+    ).toDS().groupByKey(x => x._1).mapValues(_._2)
+
+    val e = intercept[SparkException] {
+      Seq("a", "b").toDS().groupByKey(x => x)
+        .flatMapGroupsWithState(Update, NoTimeout(), 
initialState)(flatMapGroupsWithStateFunc)
+        .show()
+    }
+    assert(e.getMessage.contains(
+      "The initial state provided contained multiple rows(state) with the same 
key." +
+        " Make sure to de-duplicate the initial state before passing it."))
+  }
+
+  testQuietly("flatMapGroupsWithState - initial state - streaming initial 
state") {
+    val initialStateData = MemoryStream[(String, RunningCount)]
+    initialStateData.addData(("a", new RunningCount(1)))
+
+    val inputData = MemoryStream[String]
+
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .flatMapGroupsWithState(
+          Update, NoTimeout(), 
initialStateData.toDS().groupByKey(_._1).mapValues(_._2)
+        )(flatMapGroupsWithStateFunc)
+
+    val e = intercept[AnalysisException] {
+      result.writeStream
+        .format("console")
+        .start()
+    }
+
+    val expectedError = "Non-streaming DataFrame/Dataset is not supported" +
+      " as the initial state in [flatMap|map]GroupsWithState" +
+      " operation on a streaming DataFrame/Dataset"
+    assert(e.message.contains(expectedError))
+  }
+
+  test("flatMapGroupsWithState - initial state - initial state has 
flatMapGroupsWithState") {
+    val initialStateDS = Seq(("keyInStateAndData", new RunningCount(1))).toDS()
+    val initialState: KeyValueGroupedDataset[String, RunningCount] =
+      initialStateDS.groupByKey(_._1).mapValues(_._2)
+        .mapGroupsWithState(
+          GroupStateTimeout.NoTimeout())(
+          (key: String, values: Iterator[RunningCount], state: 
GroupState[Boolean]) => {
+            (key, values.next())
+          }
+        ).groupByKey(_._1).mapValues(_._2)
+
+    val inputData = MemoryStream[String]
+
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .flatMapGroupsWithState(
+          Update, NoTimeout(), initialState
+        )(flatMapGroupsWithStateFunc)
+
+    testStream(result, Update)(
+      AddData(inputData, "keyInStateAndData"),
+      CheckNewAnswer(
+        ("keyInStateAndData", Seq[String]("keyInStateAndData"), "2")
+      ),
+      StopStream
+    )
+  }
+
+  testWithAllStateVersions("mapGroupsWithState - initial state - null key") {
+    val mapGroupsWithStateFunc =
+      (key: String, values: Iterator[String], state: GroupState[RunningCount]) 
=> {
+        val valList = values.toList
+        val count = state.getOption.map(_.count).getOrElse(0L) + valList.size
+        state.update(new RunningCount(count))
+        (key, state.get.count.toString)
+      }
+    val initialState = Seq(
+      ("key", new RunningCount(5)),
+      (null, new RunningCount(2))
+    ).toDS().groupByKey(_._1).mapValues(_._2)
+
+    val inputData = MemoryStream[String]
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .mapGroupsWithState(NoTimeout(), initialState)(mapGroupsWithStateFunc)
+    testStream(result, Update)(
+      AddData(inputData, "key", null),
+      CheckNewAnswer(
+        ("key", "6"), // state is incremented by 1
+        (null, "3") // incremented by 1
+      ),
+      assertNumStateRows(total = 2, updated = 2),
+      StopStream
+    )
+  }
+
+  testWithAllStateVersions("flatMapGroupsWithState - initial state - 
processing time timeout") {
+    // function will return -1 on timeout and returns count of the state 
otherwise
+    val stateFunc =
+      (key: String, values: Iterator[(String, Long)], state: 
GroupState[RunningCount]) => {
+        if (state.hasTimedOut) {
+          state.remove()
+          Iterator((key, "-1"))
+        } else {
+          val count = state.getOption.map(_.count).getOrElse(0L) + values.size
+          state.update(RunningCount(count))
+          state.setTimeoutDuration("10 seconds")
+          Iterator((key, count.toString))
+        }
+      }
+
+    val clock = new StreamManualClock
+    val inputData = MemoryStream[(String, Long)]
+    val initialState = Seq(
+      ("c", new RunningCount(2))
+    ).toDS().groupByKey(_._1).mapValues(_._2)
+    val result =
+      inputData.toDF().toDF("key", "time")
+        .selectExpr("key", "timestamp_seconds(time) as timestamp")
+        .withWatermark("timestamp", "10 second")
+        .as[(String, Long)]
+        .groupByKey(x => x._1)
+        .flatMapGroupsWithState(Update, ProcessingTimeTimeout(), 
initialState)(stateFunc)
+
+    testStream(result, Update)(
+      StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
+      AddData(inputData, ("a", 1L)),
+      AdvanceManualClock(1 * 1000), // a and c are processed here for the 
first time.
+      CheckNewAnswer(("a", "1"), ("c", "2")),
+      AdvanceManualClock(10 * 1000),
+      AddData(inputData, ("b", 1L)), // this will trigger c and a to get timed 
out
+      AdvanceManualClock(1 * 1000),
+      CheckNewAnswer(("a", "-1"), ("b", "1"), ("c", "-1"))
+    )
+  }
+
+  def testWithAllStateVersions(name: String)(func: => Unit): Unit = {

Review Comment:
   This is the only duplication we will have after refactoring. I just copied 
the code over rather than introducing inheritance or a util class, as it's 
less than 10 lines.



##########
sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala:
##########
@@ -1268,258 +858,6 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest {
     assert(e.getMessage === "The output mode of function should be append or 
update")
   }
 
-  import testImplicits._

Review Comment:
   Lines L1271 to L1635 (except a single test case which isn't related) are 
moved out to FlatMapGroupsWithStateWithInitialStateSuite. Lines L1790 to L1809 
are also moved out, as they are only referred to by L1271 to L1635.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to