HeartSaVioR commented on code in PR #48686:
URL: https://github.com/apache/spark/pull/48686#discussion_r1828626947


##########
sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala:
##########
@@ -83,6 +98,49 @@ abstract class StatefulProcessorWithInitialStateTestClass[V]
   }
 }
 
+/**
+ * Stateful processor that will take a union dataframe output from state data 
source reader,
+ * with flattened state data source rows.

Review Comment:
   Shall we briefly describe what this class does with the initial state, especially for the map state? Mapping the first char to a double is a neat approach, but readers have to read the code to understand what is happening.
   
   Is it intentional to apply a transformation (changing the map key) before setting the state variable? I agree users can apply transformations in this method and aren't restricted to feeding in the same values, but it's one more thing to track when verifying the test.
   
   If we just load them as-is into the state variables for the new query, I'm OK with not adding an extra explanation.



##########
sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala:
##########
@@ -497,4 +569,135 @@ class TransformWithStateInitialStateSuite extends 
StateStoreMetricsTest
       )
     }
   }
+
+  Seq(true, false).foreach { flattenOption =>
+    Seq(("5", "2"), ("5", "8"), ("5", "5")).foreach { partitions =>
+      test("state data source reader dataframe as initial state " +
+        s"(flatten option=$flattenOption, shuffle partition for 1st 
stream=${partitions._1}, " +
+        s"shuffle partition for 1st stream=${partitions._2})") {
+        withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+          classOf[RocksDBStateStoreProvider].getName) {
+          withTempDir { checkpointDir =>
+            SQLConf.get.setConfString(SQLConf.SHUFFLE_PARTITIONS.key, 
partitions._1)
+            val inputData = MemoryStream[String]
+            val result = inputData.toDS()
+              .groupByKey(x => x)
+              .transformWithState(new StatefulProcessorWithAllStateVars(),
+                TimeMode.None(),
+                OutputMode.Update())
+
+            testStream(result, OutputMode.Update())(
+              StartStream(checkpointLocation = checkpointDir.getCanonicalPath),
+              AddData(inputData, "a", "b"),
+              CheckNewAnswer(("a", "1"), ("b", "1")),
+              AddData(inputData, "a", "b", "a"),
+              CheckNewAnswer(("a", "3"), ("b", "2"))
+            )
+
+            // We are trying to mimic a use case where users load all state 
data rows
+            // from a previous tws query as initial state and start a new tws 
query.
+            // In this use case, users will need to create a single dataframe 
with
+            // all the state rows from different state variables with 
different schema.
+            // We can only read from one state variable from one state data 
source reader
+            // query, and they are of different schema. We will get one 
dataframe from each
+            // state variable, and we union them together into a single 
dataframe.
+            val valueDf = spark.read
+              .format("statestore")
+              .option(StateSourceOptions.PATH, checkpointDir.getAbsolutePath)
+              .option(StateSourceOptions.STATE_VAR_NAME, "countState")
+              .load()
+
+            val listDf = spark.read
+              .format("statestore")
+              .option(StateSourceOptions.PATH, checkpointDir.getAbsolutePath)
+              .option(StateSourceOptions.STATE_VAR_NAME, "listState")
+              .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, 
flattenOption)
+              .load()
+
+            val mapDf = spark.read
+              .format("statestore")
+              .option(StateSourceOptions.PATH, checkpointDir.getAbsolutePath)
+              .option(StateSourceOptions.STATE_VAR_NAME, "mapState")
+              .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, 
flattenOption)
+              .load()
+
+            // create a df where each row contains all value, list, map state 
rows
+            // fill the missing column with null.
+            SQLConf.get.setConfString(SQLConf.SHUFFLE_PARTITIONS.key, 
partitions._2)
+            val inputData2 = MemoryStream[InitInputRow]
+            val query = startQueryWithDataSourceDataframeAsInitState(
+              flattenOption, valueDf, listDf, mapDf, inputData2)
+
+            testStream(query, OutputMode.Update())(
+              // check initial state is updated for state vars
+              AddData(inputData2,
+                InitInputRow("a", "getOption", 0.0),
+                InitInputRow("a", "getList", 0.0),
+                InitInputRow("a", "getCount", 0.0)),
+              CheckNewAnswer(("a", "getCount", 3.0),
+                ("a", "getList", 1.0), ("a", "getList", 2.0), ("a", "getList", 
3.0),
+                ("a", "getOption", 3.0)),
+              // check we can make updates on state vars after first batch
+              AddData(inputData2, InitInputRow("b", "update", 37.0)),

Review Comment:
   ditto, combine into a single command



##########
sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala:
##########
@@ -83,6 +98,49 @@ abstract class StatefulProcessorWithInitialStateTestClass[V]
   }
 }
 
+/**
+ * Stateful processor that will take a union dataframe output from state data 
source reader,
+ * with flattened state data source rows.
+ */
+class InitialStatefulProcessorWithStateDataSource
+  extends StatefulProcessorWithInitialStateTestClass[UnionInitialStateRow] {
+  override def handleInitialState(
+      key: String, initialState: UnionInitialStateRow, timerValues: 
TimerValues): Unit = {
+    if (initialState.value.isDefined) {
+      _valState.update(initialState.value.get.toDouble)
+    } else if (initialState.listValue.isDefined) {
+      _listState.appendValue(initialState.listValue.get.toDouble)
+    } else if (initialState.userMapKey.isDefined) {
+      _mapState.updateValue(
+        (initialState.userMapKey.get.charAt(0) - 'a').toDouble,
+        initialState.userMapValue.get.toInt)
+    }
+  }
+}
+
+/**
+ * Stateful processor that will take a union dataframe output from state data 
source reader,
+ * with composite type state data source rows.

Review Comment:
   ditto



##########
sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala:
##########
@@ -497,4 +587,121 @@ class TransformWithStateInitialStateSuite extends 
StateStoreMetricsTest
       )
     }
   }
+
+  testInitialStateWithStateDataSource(true) { (valDf, listDf, mapDf, 
inputData) =>
+    val valueDf = valDf.selectExpr("key.value AS groupingKey", "value.value AS 
value")
+    val flattenListDf = listDf
+      .selectExpr("key.value AS groupingKey", "list_element.value AS 
listValue")
+    val flattenMapDf = mapDf
+      .selectExpr(
+        "key.value AS groupingKey",
+        "user_map_key.value AS userMapKey",
+        "user_map_value.value AS userMapValue")
+    val df_joined =
+      valueDf.unionByName(flattenListDf, true)
+        .unionByName(flattenMapDf, true)
+    val kvDataSet = inputData.toDS().groupByKey(x => x.key)
+    val initDf = df_joined.as[UnionInitialStateRow].groupByKey(x => 
x.groupingKey)
+    kvDataSet.transformWithState(
+      new InitialStatefulProcessorWithStateDataSource(),
+      TimeMode.None(), OutputMode.Append(), initDf).toDF()
+  }
+
+  testInitialStateWithStateDataSource(false) { (valDf, listDf, mapDf, 
inputData) =>
+    val valueDf = valDf.selectExpr("key.value AS groupingKey", "value.value AS 
value")
+    val unflattenListDf = listDf
+      .selectExpr("key.value AS groupingKey",
+        "list_value.value as listValue")
+    val unflattenMapDf = mapDf
+      .selectExpr(
+        "key.value AS groupingKey",
+        "map_from_entries(transform(map_entries(map_value), x -> " +
+          "struct(x.key.value, x.value.value))) as mapValue")
+    val df_joined =
+      valueDf.unionByName(unflattenListDf, true)
+        .unionByName(unflattenMapDf, true)
+    val kvDataSet = inputData.toDS().groupByKey(x => x.key)
+    val initDf = df_joined.as[UnionUnflattenInitialStateRow].groupByKey(x => 
x.groupingKey)
+    kvDataSet.transformWithState(
+      new InitialStatefulProcessorWithUnflattenStateDataSource(),
+      TimeMode.None(), OutputMode.Append(), initDf).toDF()
+  }
+
+  private def testInitialStateWithStateDataSource(
+      flattenOption: Boolean)
+      (startQuery: (DataFrame, DataFrame, DataFrame,
+        MemoryStream[InitInputRow]) => DataFrame): Unit = {
+    Seq(("5", "2"), ("5", "8"), ("5", "5")).foreach { partitions =>
+      test("transformWithStateWithInitialState - state data source reader 
dataframe " +
+        s"as initial state with flatten option set to $flattenOption, the 
first stream and " +
+        s"the second stream is running on shuffle partition number of 
${partitions._1} and " +
+        s"${partitions._2} respectively.") {
+        withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+          classOf[RocksDBStateStoreProvider].getName) {
+          withTempDir { checkpointDir =>
+            SQLConf.get.setConfString(SQLConf.SHUFFLE_PARTITIONS.key, 
partitions._1)
+            val inputData = MemoryStream[String]
+            val result = inputData.toDS()
+              .groupByKey(x => x)
+              .transformWithState(new StatefulProcessorWithAllStateVars(),
+                TimeMode.None(),
+                OutputMode.Update())
+
+            testStream(result, OutputMode.Update())(
+              StartStream(checkpointLocation = checkpointDir.getCanonicalPath),
+              AddData(inputData, "a", "b"),
+              CheckNewAnswer(("a", "1"), ("b", "1")),
+              AddData(inputData, "a", "b", "a"),
+              CheckNewAnswer(("a", "3"), ("b", "2"))
+            )
+
+            // state data source reader for state vars
+            val valueDf = spark.read
+              .format("statestore")
+              .option(StateSourceOptions.PATH, checkpointDir.getAbsolutePath)
+              .option(StateSourceOptions.STATE_VAR_NAME, "countState")
+              .load()
+
+            val listDf = spark.read
+              .format("statestore")
+              .option(StateSourceOptions.PATH, checkpointDir.getAbsolutePath)
+              .option(StateSourceOptions.STATE_VAR_NAME, "listState")
+              .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, 
flattenOption)
+              .load()
+
+            val mapDf = spark.read
+              .format("statestore")
+              .option(StateSourceOptions.PATH, checkpointDir.getAbsolutePath)
+              .option(StateSourceOptions.STATE_VAR_NAME, "mapState")
+              .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, 
flattenOption)
+              .load()
+
+            // create a df where each row contains all value, list, map state 
rows

Review Comment:
   To verify that the test works correctly, a reader has to do the following:
   
   1) learn how StatefulProcessorWithAllStateVars works
   2) go through the input rows and work out by hand what the values of the three state vars would be
   3) look at how the second query is constructed with the initial state from the state data source reader
   4) learn how InitialStatefulProcessorWithStateDataSource works
   5) reason about the value of state variables for initial state
   6) go through the input rows and work out by hand what the values of the three state vars would be
   
   I understand the purpose, which is to make the test resemble end-to-end usage. However, this is a unit test, and the test name describes the `ability of setting up initial state with the result of state data source reader`. I don't expect the test to showcase the flexibility of what users can do. That belongs in the examples directory or in a blog post with a GitHub repo, not here.
   
   I think the test still verifies what it intends to if it just shows the ability to migrate a checkpoint "with the same logic of the query". The test doesn't require the reader to understand two different queries. If the test runs for a couple of batches, then continues on another checkpoint (i.e., starting from batch 0) while retaining state via the initial state (and allowing further modification), that should be enough.
   



##########
sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala:
##########
@@ -121,6 +121,8 @@ private[sql] abstract class 
StatefulProcessorWithInitialState[K, I, O, S]
 
   /**
    * Function that will be invoked only in the first batch for users to 
process initial states.
+   * Allow multiple initial state rows mapping to the same grouping key to 
support integration

Review Comment:
   Please make the method doc understandable to users, since this is a user-facing API. You can leave the rationale for allowing this as a separate comment, e.g. a `// ` comment outside the method doc, so that it isn't included in the method doc.
   
   When writing the method doc for a user-facing API, please put yourself in the shoes of a user who is just getting started.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to