HeartSaVioR commented on code in PR #48686:
URL: https://github.com/apache/spark/pull/48686#discussion_r1828626947
##########
sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala:
##########
@@ -83,6 +98,49 @@ abstract class StatefulProcessorWithInitialStateTestClass[V]
}
}
+/**
+ * Stateful processor that will take a union dataframe output from state data source reader,
+ * with flattened state data source rows.
Review Comment:
Shall we briefly describe what this class does with the initial state, especially the map state? Mapping the first char to a double is a neat approach, but readers have to go through the code to understand what is happening.
Is it intended to apply a transformation (changing the key of the map) before setting it into the state variable? I agree users can apply transformations in this method and aren't restricted to feeding the same values back, but it is one more thing to track when verifying the test.
If we just load them as they are into the state variables for the new query, I'm OK with not adding extra explanation.
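For example, a short class-level doc along these lines would make the intent explicit (this wording is only a suggestion, based on my reading of the processor below; adjust as you see fit):

```scala
/**
 * Stateful processor that loads the unioned (flattened) output of the state data source
 * reader as initial state. Value and list state rows are loaded as-is; for map state rows,
 * the map key is intentionally transformed (first char mapped to a double via `c - 'a'`)
 * before being written into the map state variable.
 */
```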
##########
sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala:
##########
@@ -497,4 +569,135 @@ class TransformWithStateInitialStateSuite extends StateStoreMetricsTest
)
}
}
+
+ Seq(true, false).foreach { flattenOption =>
+ Seq(("5", "2"), ("5", "8"), ("5", "5")).foreach { partitions =>
+ test("state data source reader dataframe as initial state " +
+ s"(flatten option=$flattenOption, shuffle partition for 1st stream=${partitions._1}, " +
+ s"shuffle partition for 2nd stream=${partitions._2})") {
+ withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+ classOf[RocksDBStateStoreProvider].getName) {
+ withTempDir { checkpointDir =>
+ SQLConf.get.setConfString(SQLConf.SHUFFLE_PARTITIONS.key, partitions._1)
+ val inputData = MemoryStream[String]
+ val result = inputData.toDS()
+ .groupByKey(x => x)
+ .transformWithState(new StatefulProcessorWithAllStateVars(),
+ TimeMode.None(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = checkpointDir.getCanonicalPath),
+ AddData(inputData, "a", "b"),
+ CheckNewAnswer(("a", "1"), ("b", "1")),
+ AddData(inputData, "a", "b", "a"),
+ CheckNewAnswer(("a", "3"), ("b", "2"))
+ )
+
+ // We are trying to mimic a use case where users load all state data rows
+ // from a previous tws query as initial state and start a new tws query.
+ // In this use case, users will need to create a single dataframe with
+ // all the state rows from different state variables with different schema.
+ // We can only read from one state variable from one state data source reader
+ // query, and they are of different schema. We will get one dataframe from each
+ // state variable, and we union them together into a single dataframe.
+ val valueDf = spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, checkpointDir.getAbsolutePath)
+ .option(StateSourceOptions.STATE_VAR_NAME, "countState")
+ .load()
+
+ val listDf = spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, checkpointDir.getAbsolutePath)
+ .option(StateSourceOptions.STATE_VAR_NAME, "listState")
+ .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, flattenOption)
+ .load()
+
+ val mapDf = spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, checkpointDir.getAbsolutePath)
+ .option(StateSourceOptions.STATE_VAR_NAME, "mapState")
+ .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, flattenOption)
+ .load()
+
+ // create a df where each row contains all value, list, map state rows
+ // fill the missing column with null.
+ SQLConf.get.setConfString(SQLConf.SHUFFLE_PARTITIONS.key, partitions._2)
+ val inputData2 = MemoryStream[InitInputRow]
+ val query = startQueryWithDataSourceDataframeAsInitState(
+ flattenOption, valueDf, listDf, mapDf, inputData2)
+
+ testStream(query, OutputMode.Update())(
+ // check initial state is updated for state vars
+ AddData(inputData2,
+ InitInputRow("a", "getOption", 0.0),
+ InitInputRow("a", "getList", 0.0),
+ InitInputRow("a", "getCount", 0.0)),
+ CheckNewAnswer(("a", "getCount", 3.0),
+ ("a", "getList", 1.0), ("a", "getList", 2.0), ("a", "getList", 3.0),
+ ("a", "getOption", 3.0)),
+ // check we can make updates on state vars after first batch
+ AddData(inputData2, InitInputRow("b", "update", 37.0)),
Review Comment:
ditto, combine into a single command
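Roughly something like the sketch below (the rows and the expected answer here are placeholders for illustration, not the actual values the test should assert):

```scala
// hypothetical combined step: one AddData for the batch, one CheckNewAnswer for its output
AddData(inputData2,
  InitInputRow("b", "update", 37.0),
  InitInputRow("b", "getOption", 0.0)),
CheckNewAnswer(("b", "getOption", 37.0)),
```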
##########
sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala:
##########
@@ -83,6 +98,49 @@ abstract class StatefulProcessorWithInitialStateTestClass[V]
}
}
+/**
+ * Stateful processor that will take a union dataframe output from state data source reader,
+ * with flattened state data source rows.
+ */
+class InitialStatefulProcessorWithStateDataSource
+ extends StatefulProcessorWithInitialStateTestClass[UnionInitialStateRow] {
+ override def handleInitialState(
+ key: String, initialState: UnionInitialStateRow, timerValues: TimerValues): Unit = {
+ if (initialState.value.isDefined) {
+ _valState.update(initialState.value.get.toDouble)
+ } else if (initialState.listValue.isDefined) {
+ _listState.appendValue(initialState.listValue.get.toDouble)
+ } else if (initialState.userMapKey.isDefined) {
+ _mapState.updateValue(
+ (initialState.userMapKey.get.charAt(0) - 'a').toDouble,
+ initialState.userMapValue.get.toInt)
+ }
+ }
+}
+
+/**
+ * Stateful processor that will take a union dataframe output from state data source reader,
+ * with composite type state data source rows.
Review Comment:
ditto
##########
sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala:
##########
@@ -497,4 +587,121 @@ class TransformWithStateInitialStateSuite extends StateStoreMetricsTest
)
}
}
+
+ testInitialStateWithStateDataSource(true) { (valDf, listDf, mapDf, inputData) =>
+ val valueDf = valDf.selectExpr("key.value AS groupingKey", "value.value AS value")
+ val flattenListDf = listDf
+ .selectExpr("key.value AS groupingKey", "list_element.value AS listValue")
+ val flattenMapDf = mapDf
+ .selectExpr(
+ "key.value AS groupingKey",
+ "user_map_key.value AS userMapKey",
+ "user_map_value.value AS userMapValue")
+ val df_joined =
+ valueDf.unionByName(flattenListDf, true)
+ .unionByName(flattenMapDf, true)
+ val kvDataSet = inputData.toDS().groupByKey(x => x.key)
+ val initDf = df_joined.as[UnionInitialStateRow].groupByKey(x => x.groupingKey)
+ kvDataSet.transformWithState(
+ new InitialStatefulProcessorWithStateDataSource(),
+ TimeMode.None(), OutputMode.Append(), initDf).toDF()
+ }
+
+ testInitialStateWithStateDataSource(false) { (valDf, listDf, mapDf, inputData) =>
+ val valueDf = valDf.selectExpr("key.value AS groupingKey", "value.value AS
value")
+ val unflattenListDf = listDf
+ .selectExpr("key.value AS groupingKey",
+ "list_value.value as listValue")
+ val unflattenMapDf = mapDf
+ .selectExpr(
+ "key.value AS groupingKey",
+ "map_from_entries(transform(map_entries(map_value), x -> " +
+ "struct(x.key.value, x.value.value))) as mapValue")
+ val df_joined =
+ valueDf.unionByName(unflattenListDf, true)
+ .unionByName(unflattenMapDf, true)
+ val kvDataSet = inputData.toDS().groupByKey(x => x.key)
+ val initDf = df_joined.as[UnionUnflattenInitialStateRow].groupByKey(x => x.groupingKey)
+ kvDataSet.transformWithState(
+ new InitialStatefulProcessorWithUnflattenStateDataSource(),
+ TimeMode.None(), OutputMode.Append(), initDf).toDF()
+ }
+
+ private def testInitialStateWithStateDataSource(
+ flattenOption: Boolean)
+ (startQuery: (DataFrame, DataFrame, DataFrame, MemoryStream[InitInputRow]) => DataFrame): Unit = {
+ Seq(("5", "2"), ("5", "8"), ("5", "5")).foreach { partitions =>
+ test("transformWithStateWithInitialState - state data source reader dataframe " +
+ s"as initial state with flatten option set to $flattenOption, the first stream and " +
+ s"the second stream are running on shuffle partition numbers of ${partitions._1} and " +
+ s"${partitions._2} respectively.") {
+ withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+ classOf[RocksDBStateStoreProvider].getName) {
+ withTempDir { checkpointDir =>
+ SQLConf.get.setConfString(SQLConf.SHUFFLE_PARTITIONS.key, partitions._1)
+ val inputData = MemoryStream[String]
+ val result = inputData.toDS()
+ .groupByKey(x => x)
+ .transformWithState(new StatefulProcessorWithAllStateVars(),
+ TimeMode.None(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = checkpointDir.getCanonicalPath),
+ AddData(inputData, "a", "b"),
+ CheckNewAnswer(("a", "1"), ("b", "1")),
+ AddData(inputData, "a", "b", "a"),
+ CheckNewAnswer(("a", "3"), ("b", "2"))
+ )
+
+ // state data source reader for state vars
+ val valueDf = spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, checkpointDir.getAbsolutePath)
+ .option(StateSourceOptions.STATE_VAR_NAME, "countState")
+ .load()
+
+ val listDf = spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, checkpointDir.getAbsolutePath)
+ .option(StateSourceOptions.STATE_VAR_NAME, "listState")
+ .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, flattenOption)
+ .load()
+
+ val mapDf = spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, checkpointDir.getAbsolutePath)
+ .option(StateSourceOptions.STATE_VAR_NAME, "mapState")
+ .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, flattenOption)
+ .load()
+
+ // create a df where each row contains all value, list, map state rows
Review Comment:
To verify that this test works correctly, a reader has to do the following:
1) learn how StatefulProcessorWithAllStateVars works
2) go through the input rows and work out by themselves what the values of the three state vars would look like
3) look at how the second query is constructed with initial state from the state data source reader
4) learn how InitialStatefulProcessorWithStateDataSource works
5) reason about the values of the state variables for the initial state
6) go through the input rows and work out by themselves what the values of the three state vars would look like
I understand the purpose of making the test closer to e2e usage, but this is a unit test, and what we describe in the test name is the `ability of setting up initial state with the result of state data source reader`. I don't expect the test to show off the flexibility of what users can do. That should be placed in the examples directory or some blog post with a GitHub repo, not here.
I think the test still verifies what it intends to do if it just shows the ability to migrate a checkpoint "with the same logic of the query". The test shouldn't require the reader to understand two different queries; if the test runs for a couple of batches, and can continue with another checkpoint (meaning batch 0) while retaining state via initial state (and allowing further modification), that should be enough.
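Concretely, something along the lines of the sketch below would be enough in my view (a rough sketch only; `unionedStateDf` is a placeholder for the union of the three state data source reader dataframes, and the expected values would need to match the actual processors in this suite):

```scala
// Sketch: restart with the same query logic on a fresh checkpoint, seeding it with the
// state rows read back from the first run via the state data source reader.
val initDf = unionedStateDf.as[UnionInitialStateRow].groupByKey(_.groupingKey)
val restarted = inputData2.toDS()
  .groupByKey(x => x.key)
  .transformWithState(
    new InitialStatefulProcessorWithStateDataSource(),
    TimeMode.None(), OutputMode.Update(), initDf)

testStream(restarted, OutputMode.Update())(
  // batch 0 of the new checkpoint: the count from the first run should be retained
  AddData(inputData2, InitInputRow("a", "getCount", 0.0)),
  CheckNewAnswer(("a", "getCount", 3.0))
  // ... a couple more AddData/CheckNewAnswer batches to show state can still be modified
)
```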
##########
sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala:
##########
@@ -121,6 +121,8 @@ private[sql] abstract class StatefulProcessorWithInitialState[K, I, O, S]
/**
* Function that will be invoked only in the first batch for users to process initial states.
+ * Allow multiple initial state rows mapping to the same grouping key to support integration
Review Comment:
Please make the method doc understandable to users, since this is a user-facing API. You can leave the rationale for allowing this as a separate comment, e.g. `// ` outside of the method doc, so that part isn't included in the method doc.
Please put yourself in the shoes of a user who is just a "starter" when writing the method doc of a user-facing API.
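For instance, a split roughly like this (just a sketch of the shape I mean, not the exact wording):

```scala
// Rationale (kept out of the user-facing doc): multiple initial state rows per grouping
// key are allowed so that the unioned output of the state data source reader can be fed
// directly as initial state.
/**
 * Function that will be invoked only in the first batch for users to process initial states.
 * Note that this function may be called multiple times for the same grouping key if more
 * than one initial state row maps to that key.
 */
```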