HeartSaVioR commented on code in PR #48686:
URL: https://github.com/apache/spark/pull/48686#discussion_r1833952122


##########
sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala:
##########
@@ -503,4 +567,146 @@ class TransformWithStateInitialStateSuite extends 
StateStoreMetricsTest
       )
     }
   }
+
+  Seq(true, false).foreach { flattenOption =>
+    Seq(("5", "2"), ("5", "8"), ("5", "5")).foreach { partitions =>
+      test("state data source reader dataframe as initial state " +
+        s"(flatten option=$flattenOption, shuffle partition for 1st 
stream=${partitions._1}, " +
+        s"shuffle partition for 1st stream=${partitions._2})") {
+        withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+          classOf[RocksDBStateStoreProvider].getName) {
+          withTempPaths(2) { checkpointDirs =>
+            SQLConf.get.setConfString(SQLConf.SHUFFLE_PARTITIONS.key, 
partitions._1)
+            val inputData = MemoryStream[String]
+            val result = inputData.toDS()
+              .groupByKey(x => x)
+              .transformWithState(new 
InitialStatefulProcessorWithStateDataSource(),
+                TimeMode.None(),
+                OutputMode.Update())
+
+            testStream(result, OutputMode.Update())(
+              StartStream(checkpointLocation = 
checkpointDirs(0).getCanonicalPath),
+              AddData(inputData, "a", "b"),
+              CheckNewAnswer(("a", "1"), ("b", "1")),
+              AddData(inputData, "a", "b", "a"),
+              CheckNewAnswer(("a", "3"), ("b", "2"))
+            )
+
+            // We are trying to mimic a use case where users load all state 
data rows
+            // from a previous tws query as initial state and start a new tws 
query.
+            // In this use case, users will need to create a single dataframe 
with
+            // all the state rows from different state variables with 
different schema.
+            // We can only read from one state variable from one state data 
source reader
+            // query, and they are of different schema. We will get one 
dataframe from each
+            // state variable, and we union them together into a single 
dataframe.
+            val valueDf = spark.read
+              .format("statestore")
+              .option(StateSourceOptions.PATH, 
checkpointDirs(0).getAbsolutePath)
+              .option(StateSourceOptions.STATE_VAR_NAME, "testVal")
+              .load()
+
+            val listDf = spark.read
+              .format("statestore")
+              .option(StateSourceOptions.PATH, 
checkpointDirs(0).getAbsolutePath)
+              .option(StateSourceOptions.STATE_VAR_NAME, "testList")
+              .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, 
flattenOption)
+              .load()
+
+            val mapDf = spark.read
+              .format("statestore")
+              .option(StateSourceOptions.PATH, 
checkpointDirs(0).getAbsolutePath)
+              .option(StateSourceOptions.STATE_VAR_NAME, "testMap")
+              .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, 
flattenOption)
+              .load()
+
+            // create a df where each row contains all value, list, map state 
rows
+            // fill the missing column with null.
+            SQLConf.get.setConfString(SQLConf.SHUFFLE_PARTITIONS.key, 
partitions._2)
+            val inputData2 = MemoryStream[String]
+            val query = startQueryWithDataSourceDataframeAsInitState(
+              flattenOption, valueDf, listDf, mapDf, inputData2)
+
+            testStream(query, OutputMode.Update())(
+              StartStream(checkpointLocation = 
checkpointDirs(1).getCanonicalPath),
+              // check initial state is updated for state vars
+              AddData(inputData2, "c"),
+              CheckNewAnswer(("c", "1")),
+              Execute { _ =>
+                // value state var is checked by the stateful processor output
+                val listDf2 = spark.read
+                  .format("statestore")
+                  .option(StateSourceOptions.PATH, 
checkpointDirs(1).getAbsolutePath)
+                  .option(StateSourceOptions.STATE_VAR_NAME, "testList")
+                  .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, 
flattenOption)
+                  .load()
+                  .drop("partition_id")
+                  .filter(col("key.value") =!= "c")
+                val mapDf2 = spark.read
+                  .format("statestore")
+                  .option(StateSourceOptions.PATH, 
checkpointDirs(1).getAbsolutePath)
+                  .option(StateSourceOptions.STATE_VAR_NAME, "testMap")
+                  .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, 
flattenOption)
+                  .load()
+                  .drop("partition_id")
+                  .filter(col("key.value") =!= "c")
+
+                // simple validation on initial state process
+                assert(listDf2.count() == listDf.count())

Review Comment:
   nit: shall we use `checkAnswer` to compare listDf with listDf2, and mapDf 
with mapDf2? Since we exclude the state from the new input, they should be 
exactly the same. We need to drop the `partition_id` column from both sides before the comparison.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to