This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 5432cef483c4 [SPARK-50152][SS] Support handleInitialState with state data source reader
5432cef483c4 is described below
commit 5432cef483c499e3548a30d608550bac9fce53ec
Author: jingz-db <[email protected]>
AuthorDate: Tue Nov 12 10:43:48 2024 +0900
[SPARK-50152][SS] Support handleInitialState with state data source reader
### What changes were proposed in this pull request?
This PR adds support for users to provide a DataFrame that can be used to
instantiate state for the query in the first batch of the arbitrary state API v2.
More specifically, this dataframe comes from the state data source reader.
It also removes the restriction that the initialState dataframe can contain only
one value row per grouping key. This is needed to enable integration with the
state data source reader: the flattened state data source reader output for a
composite type can have multiple value rows mapping to the same grouping key.
For example, we can take the dataframes created by the state data source reader
on each individual state variable, union them together, and use the resulting
dataframe as the initial state for a transformWithState operator, like this:
```
+-----------+-----+---------+----------+------------+
|groupingKey|value|listValue|userMapKey|userMapValue|
+-----------+-----+---------+----------+------------+
|a |3 |NULL |NULL |NULL |
|b |2 |NULL |NULL |NULL |
|a |NULL |1 |NULL |NULL |
|a |NULL |2 |NULL |NULL |
|a |NULL |3 |NULL |NULL |
|b |NULL |1 |NULL |NULL |
|b |NULL |2 |NULL |NULL |
|a |NULL |NULL |a |3 |
|b |NULL |NULL |b |2 |
+-----------+-----+---------+----------+------------+
```
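As a minimal sketch (adapted from the new test in `TransformWithStateInitialStateSuite`;
`checkpointDir`, `inputStream`, the state variable names and the processor class are the
ones used in the test and purely illustrative), such a union could be built and passed to
`transformWithState` roughly like this:
```
// Sketch only: checkpointDir, inputStream and the state variable names are illustrative.
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions
import spark.implicits._  // assumes a SparkSession named `spark` is in scope

// Read each state variable of the previous query from its checkpoint and project it
// onto a flat (groupingKey, ...) shape.
val valueDf = spark.read
  .format("statestore")
  .option(StateSourceOptions.PATH, checkpointDir)
  .option(StateSourceOptions.STATE_VAR_NAME, "testVal")
  .load()
  .drop("partition_id")
  .selectExpr("key.value AS groupingKey", "value.value AS value")

val listDf = spark.read
  .format("statestore")
  .option(StateSourceOptions.PATH, checkpointDir)
  .option(StateSourceOptions.STATE_VAR_NAME, "testList")
  .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, true)
  .load()
  .drop("partition_id")
  .selectExpr("key.value AS groupingKey", "list_element.value AS listValue")

val mapDf = spark.read
  .format("statestore")
  .option(StateSourceOptions.PATH, checkpointDir)
  .option(StateSourceOptions.STATE_VAR_NAME, "testMap")
  .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, true)
  .load()
  .drop("partition_id")
  .selectExpr(
    "key.value AS groupingKey",
    "user_map_key.value AS userMapKey",
    "user_map_value.value AS userMapValue")

// Union the per-variable dataframes; columns missing in a branch are filled with NULL.
val initDf = valueDf
  .unionByName(listDf, allowMissingColumns = true)
  .unionByName(mapDf, allowMissingColumns = true)
  .as[UnionInitialStateRow]
  .groupByKey(_.groupingKey)

// Start the new query with the unioned dataframe as its initial state.
val result = inputStream.toDS()
  .groupByKey(x => x)
  .transformWithState(new InitialStatefulProcessorWithStateDataSource(),
    TimeMode.None(), OutputMode.Append(), initDf)
```
This mirrors what `startQueryWithDataSourceDataframeAsInitState` does in the added test.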
### Why are the changes needed?
This change supports initial state handling when integrating with the state
data source reader.
### Does this PR introduce _any_ user-facing change?
No. The user API is the same as in the prior PR
(https://github.com/apache/spark/pull/45467), which added initial state support
without the state data source reader.
### How was this patch tested?
Unit test cases added in `TransformWithStateInitialStateSuite`.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #48686 from jingz-db/initial-state-reader-integration.
Lead-authored-by: jingz-db <[email protected]>
Co-authored-by: Jungtaek Lim <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
---
.../src/main/resources/error/error-conditions.json | 6 -
.../spark/sql/streaming/StatefulProcessor.scala | 3 +
.../streaming/TransformWithStateExec.scala | 8 +-
.../streaming/state/StateStoreErrors.scala | 10 -
.../TransformWithStateInitialStateSuite.scala | 286 ++++++++++++++++++---
5 files changed, 257 insertions(+), 56 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 7ef6feae0845..154fee2eefb7 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -4233,12 +4233,6 @@
],
"sqlState" : "42802"
},
- "STATEFUL_PROCESSOR_CANNOT_REINITIALIZE_STATE_ON_KEY" : {
- "message" : [
- "Cannot re-initialize state on the same grouping key during initial
state handling for stateful processor. Invalid grouping key=<groupingKey>."
- ],
- "sqlState" : "42802"
- },
"STATEFUL_PROCESSOR_DUPLICATE_STATE_VARIABLE_DEFINED" : {
"message" : [
"State variable with name <stateVarName> has already been defined in the
StatefulProcessor."
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala
index 719d1e572c20..55477b4dda0c 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala
@@ -121,6 +121,9 @@ private[sql] abstract class StatefulProcessorWithInitialState[K, I, O, S]
/**
* Function that will be invoked only in the first batch for users to process initial states.
+ * The provided initial state can be arbitrary dataframe with the same grouping key schema with
+ * the input rows, e.g. dataframe from data source reader of existing streaming query
+ * checkpoint.
*
* @param key
* \- grouping key
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index 4f7a10f88245..2b26d18019d1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -271,13 +271,9 @@ case class TransformWithStateExec(
ImplicitGroupingKeyTracker.setImplicitKey(keyObj)
val initStateObjIter = initStateIter.map(getInitStateValueObj.apply)
- var seenInitStateOnKey = false
initStateObjIter.foreach { initState =>
- // cannot re-initialize state on the same grouping key during initial state handling
- if (seenInitStateOnKey) {
- throw StateStoreErrors.cannotReInitializeStateOnKey(keyObj.toString)
- }
- seenInitStateOnKey = true
+ // allow multiple initial state rows on the same grouping key for integration
+ // with state data source reader with initial state
statefulProcessor
.asInstanceOf[StatefulProcessorWithInitialState[Any, Any, Any, Any]]
.handleInitialState(keyObj, initState,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
index e4b370e67b01..45ad7e14c52d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
@@ -123,11 +123,6 @@ object StateStoreErrors {
new StatefulProcessorCannotPerformOperationWithInvalidHandleState(operationType, handleState)
}
- def cannotReInitializeStateOnKey(groupingKey: String):
- StatefulProcessorCannotReInitializeState = {
- new StatefulProcessorCannotReInitializeState(groupingKey)
- }
-
def cannotProvideTTLConfigForTimeMode(stateName: String, timeMode: String):
StatefulProcessorCannotAssignTTLInTimeMode = {
new StatefulProcessorCannotAssignTTLInTimeMode(stateName, timeMode)
@@ -272,11 +267,6 @@ class StatefulProcessorCannotPerformOperationWithInvalidHandleState(
messageParameters = Map("operationType" -> operationType, "handleState" -> handleState)
)
-class StatefulProcessorCannotReInitializeState(groupingKey: String)
- extends SparkUnsupportedOperationException(
- errorClass = "STATEFUL_PROCESSOR_CANNOT_REINITIALIZE_STATE_ON_KEY",
- messageParameters = Map("groupingKey" -> groupingKey))
-
class StateStoreUnsupportedOperationOnMissingColumnFamily(
operationType: String,
colFamilyName: String) extends SparkUnsupportedOperationException(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala
index 35ac8a4687eb..360656a76f35 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala
@@ -17,11 +17,11 @@
package org.apache.spark.sql.streaming
-import org.apache.spark.SparkUnsupportedOperationException
-import org.apache.spark.sql.{Dataset, Encoders, KeyValueGroupedDataset}
+import org.apache.spark.sql.{DataFrame, Dataset, Encoders, KeyValueGroupedDataset}
+import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider}
-import org.apache.spark.sql.functions.timestamp_seconds
+import org.apache.spark.sql.functions.{col, timestamp_seconds}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.util.StreamManualClock
@@ -29,6 +29,21 @@ case class InitInputRow(key: String, action: String, value: Double)
case class InputRowForInitialState(
key: String, value: Double, entries: List[Double], mapping: Map[Double, Int])
+case class UnionInitialStateRow(
+ groupingKey: String,
+ value: Option[Long],
+ listValue: Option[Long],
+ userMapKey: Option[String],
+ userMapValue: Option[Long]
+)
+
+case class UnionUnflattenInitialStateRow(
+ groupingKey: String,
+ value: Option[Long],
+ listValue: Option[Seq[Long]],
+ mapValue: Option[Map[String, Long]]
+)
+
abstract class StatefulProcessorWithInitialStateTestClass[V]
extends StatefulProcessorWithInitialState[
String, InitInputRow, (String, String, Double), V] {
@@ -86,6 +101,86 @@ abstract class StatefulProcessorWithInitialStateTestClass[V]
}
}
+/**
+ * Class that updates all state variables with input rows. Act as a counter -
+ * keep the count in value state; keep all the occurrences in list state; and
+ * keep a map of key -> occurrence count in the map state.
+ */
+abstract class InitialStateWithStateDataSourceBase[V]
+ extends StatefulProcessorWithInitialState[
+ String, String, (String, String), V] {
+ @transient var _valState: ValueState[Long] = _
+ @transient var _listState: ListState[Long] = _
+ @transient var _mapState: MapState[String, Long] = _
+
+ override def init(
+ outputMode: OutputMode,
+ timeMode: TimeMode): Unit = {
+ _valState = getHandle.getValueState[Long]("testVal", Encoders.scalaLong, TTLConfig.NONE)
+ _listState = getHandle.getListState[Long]("testList", Encoders.scalaLong, TTLConfig.NONE)
+ _mapState = getHandle.getMapState[String, Long](
+ "testMap", Encoders.STRING, Encoders.scalaLong, TTLConfig.NONE)
+ }
+
+ override def handleInputRows(
+ key: String,
+ inputRows: Iterator[String],
+ timerValues: TimerValues): Iterator[(String, String)] = {
+ val curCountValue = if (_valState.exists()) {
+ _valState.get()
+ } else {
+ 0L
+ }
+ var cnt = curCountValue
+ inputRows.foreach { row =>
+ cnt += 1
+ _listState.appendValue(cnt)
+ val mapCurVal = if (_mapState.containsKey(row)) {
+ _mapState.getValue(row)
+ } else {
+ 0
+ }
+ _mapState.updateValue(row, mapCurVal + 1L)
+ }
+ _valState.update(cnt)
+ Iterator.single((key, cnt.toString))
+ }
+
+ override def close(): Unit = super.close()
+}
+
+class InitialStatefulProcessorWithStateDataSource
+ extends InitialStateWithStateDataSourceBase[UnionInitialStateRow] {
+ override def handleInitialState(
+ key: String, initialState: UnionInitialStateRow, timerValues: TimerValues): Unit = {
+ if (initialState.value.isDefined) {
+ _valState.update(initialState.value.get)
+ } else if (initialState.listValue.isDefined) {
+ _listState.appendValue(initialState.listValue.get)
+ } else if (initialState.userMapKey.isDefined) {
+ _mapState.updateValue(
+ initialState.userMapKey.get, initialState.userMapValue.get)
+ }
+ }
+}
+
+class InitialStatefulProcessorWithUnflattenStateDataSource
+ extends InitialStateWithStateDataSourceBase[UnionUnflattenInitialStateRow] {
+ override def handleInitialState(
+ key: String, initialState: UnionUnflattenInitialStateRow, timerValues: TimerValues): Unit = {
+ if (initialState.value.isDefined) {
+ _valState.update(initialState.value.get)
+ } else if (initialState.listValue.isDefined) {
+ _listState.appendList(
+ initialState.listValue.get.toArray)
+ } else if (initialState.mapValue.isDefined) {
+ initialState.mapValue.get.keys.foreach { key =>
+ _mapState.updateValue(key, initialState.mapValue.get.get(key).get)
+ }
+ }
+ }
+}
+
class AccumulateStatefulProcessorWithInitState
extends StatefulProcessorWithInitialStateTestClass[(String, Double)] {
override def handleInitialState(
@@ -398,37 +493,6 @@ class TransformWithStateInitialStateSuite extends StateStoreMetricsTest
checkAnswer(df, Seq(("k1", "getOption", 37.0)).toDF())
}
- test("transformWithStateWithInitialState - " +
- "cannot re-initialize state during initial state handling") {
- withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
- classOf[RocksDBStateStoreProvider].getName) {
- val initDf = Seq(("init_1", 40.0), ("init_2", 100.0), ("init_1",
50.0)).toDS()
- .groupByKey(x => x._1).mapValues(x => x)
- val inputData = MemoryStream[InitInputRow]
- val query = inputData.toDS()
- .groupByKey(x => x.key)
- .transformWithState(new AccumulateStatefulProcessorWithInitState(),
- TimeMode.None(),
- OutputMode.Append(),
- initDf)
-
- testStream(query, OutputMode.Update())(
- AddData(inputData, InitInputRow("k1", "add", 50.0)),
- Execute { q =>
- val e = intercept[Exception] {
- q.processAllAvailable()
- }
- checkError(
- exception = e.getCause.asInstanceOf[SparkUnsupportedOperationException],
- condition = "STATEFUL_PROCESSOR_CANNOT_REINITIALIZE_STATE_ON_KEY",
- sqlState = Some("42802"),
- parameters = Map("groupingKey" -> "init_1")
- )
- }
- )
- }
- }
-
test("transformWithStateWithInitialState - streaming with processing time
timer, " +
"can emit expired initial state rows when grouping key is not received for
new input rows") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
@@ -503,4 +567,158 @@ class TransformWithStateInitialStateSuite extends StateStoreMetricsTest
)
}
}
+
+ Seq(true, false).foreach { flattenOption =>
+ Seq(("5", "2"), ("5", "8"), ("5", "5")).foreach { partitions =>
+ test("state data source reader dataframe as initial state " +
+ s"(flatten option=$flattenOption, shuffle partition for 1st
stream=${partitions._1}, " +
+ s"shuffle partition for 1st stream=${partitions._2})") {
+ withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+ classOf[RocksDBStateStoreProvider].getName) {
+ withTempPaths(2) { checkpointDirs =>
+ SQLConf.get.setConfString(SQLConf.SHUFFLE_PARTITIONS.key, partitions._1)
+ val inputData = MemoryStream[String]
+ val result = inputData.toDS()
+ .groupByKey(x => x)
+ .transformWithState(new InitialStatefulProcessorWithStateDataSource(),
+ TimeMode.None(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = checkpointDirs(0).getCanonicalPath),
+ AddData(inputData, "a", "b"),
+ CheckNewAnswer(("a", "1"), ("b", "1")),
+ AddData(inputData, "a", "b", "a"),
+ CheckNewAnswer(("a", "3"), ("b", "2"))
+ )
+
+ // We are trying to mimic a use case where users load all state data rows
+ // from a previous tws query as initial state and start a new tws query.
+ // In this use case, users will need to create a single dataframe with
+ // all the state rows from different state variables with different schema.
+ // We can only read from one state variable from one state data source reader
+ // query, and they are of different schema. We will get one dataframe from each
+ // state variable, and we union them together into a single dataframe.
+ val valueDf = spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, checkpointDirs(0).getAbsolutePath)
+ .option(StateSourceOptions.STATE_VAR_NAME, "testVal")
+ .load()
+ .drop("partition_id")
+
+ val listDf = spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, checkpointDirs(0).getAbsolutePath)
+ .option(StateSourceOptions.STATE_VAR_NAME, "testList")
+ .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, flattenOption)
+ .load()
+ .drop("partition_id")
+
+ val mapDf = spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, checkpointDirs(0).getAbsolutePath)
+ .option(StateSourceOptions.STATE_VAR_NAME, "testMap")
+ .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, flattenOption)
+ .load()
+ .drop("partition_id")
+
+ // create a df where each row contains all value, list, map state rows
+ // fill the missing column with null.
+ SQLConf.get.setConfString(SQLConf.SHUFFLE_PARTITIONS.key, partitions._2)
+ val inputData2 = MemoryStream[String]
+ val query = startQueryWithDataSourceDataframeAsInitState(
+ flattenOption, valueDf, listDf, mapDf, inputData2)
+
+ testStream(query, OutputMode.Update())(
+ StartStream(checkpointLocation = checkpointDirs(1).getCanonicalPath),
+ // check initial state is updated for state vars
+ AddData(inputData2, "c"),
+ CheckNewAnswer(("c", "1")),
+ Execute { _ =>
+ val valueDf2 = spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, checkpointDirs(1).getAbsolutePath)
+ .option(StateSourceOptions.STATE_VAR_NAME, "testVal")
+ .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, flattenOption)
+ .load()
+ .drop("partition_id")
+ .filter(col("key.value") =!= "c")
+
+ val listDf2 = spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, checkpointDirs(1).getAbsolutePath)
+ .option(StateSourceOptions.STATE_VAR_NAME, "testList")
+ .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, flattenOption)
+ .load()
+ .drop("partition_id")
+ .filter(col("key.value") =!= "c")
+
+ val mapDf2 = spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, checkpointDirs(1).getAbsolutePath)
+ .option(StateSourceOptions.STATE_VAR_NAME, "testMap")
+ .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, flattenOption)
+ .load()
+ .drop("partition_id")
+ .filter(col("key.value") =!= "c")
+
+ checkAnswer(valueDf, valueDf2)
+ checkAnswer(listDf, listDf2)
+ checkAnswer(mapDf, mapDf2)
+ }
+ )
+ }
+ }
+ }
+ }
+ }
+
+ private def startQueryWithDataSourceDataframeAsInitState(
+ flattenOption: Boolean,
+ valDf: DataFrame,
+ listDf: DataFrame,
+ mapDf: DataFrame,
+ inputData: MemoryStream[String]): DataFrame = {
+ if (flattenOption) {
+ // when we read the state rows with flattened option set to true, values of a composite
+ // state variable will be flattened into multiple rows where each row is a
+ // key -> single value pair
+ val valueDf = valDf.selectExpr("key.value AS groupingKey", "value.value
AS value")
+ val flattenListDf = listDf
+ .selectExpr("key.value AS groupingKey", "list_element.value AS
listValue")
+ val flattenMapDf = mapDf
+ .selectExpr(
+ "key.value AS groupingKey",
+ "user_map_key.value AS userMapKey",
+ "user_map_value.value AS userMapValue")
+ val df_joined =
+ valueDf.unionByName(flattenListDf, true)
+ .unionByName(flattenMapDf, true)
+ val kvDataSet = inputData.toDS().groupByKey(x => x)
+ val initDf = df_joined.as[UnionInitialStateRow].groupByKey(x => x.groupingKey)
+ (kvDataSet.transformWithState(
+ new InitialStatefulProcessorWithStateDataSource(),
+ TimeMode.None(), OutputMode.Append(), initDf).toDF())
+ } else {
+ // when we read the state rows with flattened option set to false, values of a composite
+ // state variable will be composed into a single row of list/map type
+ val valueDf = valDf.selectExpr("key.value AS groupingKey", "value.value
AS value")
+ val unflattenListDf = listDf
+ .selectExpr("key.value AS groupingKey",
+ "list_value.value as listValue")
+ val unflattenMapDf = mapDf
+ .selectExpr(
+ "key.value AS groupingKey",
+ "map_from_entries(transform(map_entries(map_value), x -> " +
+ "struct(x.key.value, x.value.value))) as mapValue")
+ val df_joined =
+ valueDf.unionByName(unflattenListDf, true)
+ .unionByName(unflattenMapDf, true)
+ val kvDataSet = inputData.toDS().groupByKey(x => x)
+ val initDf = df_joined.as[UnionUnflattenInitialStateRow].groupByKey(x => x.groupingKey)
+ kvDataSet.transformWithState(
+ new InitialStatefulProcessorWithUnflattenStateDataSource(),
+ TimeMode.None(), OutputMode.Append(), initDf).toDF()
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]