This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 5d1d44ff1201 [SPARK-49467][SS] Add support for state data source reader and list state
5d1d44ff1201 is described below
commit 5d1d44ff1201af562c87ed2898d67f04e3292683
Author: Anish Shrigondekar <[email protected]>
AuthorDate: Fri Sep 6 17:01:52 2024 +0900
[SPARK-49467][SS] Add support for state data source reader and list state
### What changes were proposed in this pull request?
Add support for state data source reader and list state
### Why are the changes needed?
This change adds support for reading state written to list state variables, which are used within stateful processors running with the `transformWithState` operator.
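For context, the list state being read here is populated by a stateful processor. Below is a condensed sketch based on the `SessionGroupsStatefulProcessor` added in this patch's test suite; the variable name it registers, `groupsList`, is what the reader's `STATE_VAR_NAME` option refers to:
```
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.streaming.{ExpiredTimerInfo, ListState, OutputMode, StatefulProcessor, TimeMode, TimerValues}

// Condensed from SessionGroupsStatefulProcessor in the test suite added below.
class SessionGroupsStatefulProcessor extends
  StatefulProcessor[String, (String, String), String] {
  @transient private var _groupsList: ListState[String] = _

  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
    // Register a list state variable named "groupsList".
    _groupsList = getHandle.getListState("groupsList", Encoders.STRING)
  }

  override def handleInputRows(
      key: String,
      inputRows: Iterator[(String, String)],
      timerValues: TimerValues,
      expiredTimerInfo: ExpiredTimerInfo): Iterator[String] = {
    // Append the group of each input row to the list kept for the grouping key.
    inputRows.foreach { inputRow => _groupsList.appendValue(inputRow._2) }
    Iterator.empty
  }
}
```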
### Does this PR introduce _any_ user-facing change?
Yes
Users can read the list state and `explode` its entries using a query such as the following:
```
val stateReaderDf = spark.read
.format("statestore")
.option(StateSourceOptions.PATH, <checkpoint_location>)
.option(StateSourceOptions.STATE_VAR_NAME, <state_var_name>)
.load()
val listStateDf = stateReaderDf
.selectExpr(
"key.value AS groupingKey",
"list_value AS valueList",
"partition_id")
.select($"groupingKey",
explode($"valueList").as("valueList"))
```
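For a list state variable, the rows returned by the reader carry three columns: the grouping `key`, a `list_value` array whose element type is the state value schema, and the `partition_id` (value state correspondingly exposes a `single_value` column instead of `value`). A minimal sketch of that shape, using placeholder key/value schemas for illustration:
```
import org.apache.spark.sql.types._

// Placeholder schemas for illustration only; the real ones come from the
// column family registered for the state variable.
val keySchema = new StructType().add("value", StringType)
val valueSchema = new StructType().add("value", StringType)

// Shape of the rows returned for a list state variable, mirroring the
// schema construction added to SchemaUtil in this patch.
val readerSchema = new StructType()
  .add("key", keySchema)
  .add("list_value", ArrayType(valueSchema))
  .add("partition_id", IntegerType)
```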
### How was this patch tested?
Added unit tests
```
[info] Run completed in 1 minute, 3 seconds.
[info] Total number of tests run: 8
[info] Suites: completed 1, aborted 0
[info] Tests: succeeded 8, failed 0, canceled 0, ignored 0, pending 0
[info] All tests passed.
```
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #47978 from anishshri-db/task/SPARK-49467.
Authored-by: Anish Shrigondekar <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
---
.../v2/state/StatePartitionReader.scala | 46 ++++--
.../datasources/v2/state/utils/SchemaUtil.scala | 55 +++----
.../StateDataSourceTransformWithStateSuite.scala | 161 ++++++++++++++++++++-
3 files changed, 216 insertions(+), 46 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
index 53576c335cb0..1af2ec174c66 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
@@ -19,12 +19,13 @@ package org.apache.spark.sql.execution.datasources.v2.state
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
+import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
import org.apache.spark.sql.execution.streaming.{StateVariableType, TransformWithStateVariableInfo}
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.execution.streaming.state.RecordType.{getRecordTypeAsString, RecordType}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{NullType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.{NextIterator, SerializableConfiguration}
@@ -68,10 +69,20 @@ abstract class StatePartitionReaderBase(
stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema])
extends PartitionReader[InternalRow] with Logging {
+ // Used primarily as a placeholder for the value schema in the context of
+ // state variables used within the transformWithState operator.
+ private val schemaForValueRow: StructType =
+ StructType(Array(StructField("__dummy__", NullType)))
+
protected val keySchema = SchemaUtil.getSchemaAsDataType(
schema, "key").asInstanceOf[StructType]
- protected val valueSchema = SchemaUtil.getSchemaAsDataType(
- schema, "value").asInstanceOf[StructType]
+
+ protected val valueSchema = if (stateVariableInfoOpt.isDefined) {
+ schemaForValueRow
+ } else {
+ SchemaUtil.getSchemaAsDataType(
+ schema, "value").asInstanceOf[StructType]
+ }
protected lazy val provider: StateStoreProvider = {
val stateStoreId =
StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
@@ -84,10 +95,17 @@ abstract class StatePartitionReaderBase(
false
}
+ val useMultipleValuesPerKey = if (stateVariableInfoOpt.isDefined &&
+ stateVariableInfoOpt.get.stateVariableType == StateVariableType.ListState) {
+ true
+ } else {
+ false
+ }
+
val provider = StateStoreProvider.createAndInit(
stateStoreProviderId, keySchema, valueSchema, keyStateEncoderSpec,
useColumnFamilies = useColFamilies, storeConf, hadoopConf.value,
- useMultipleValuesPerKey = false)
+ useMultipleValuesPerKey = useMultipleValuesPerKey)
if (useColFamilies) {
val store = provider.getStore(partition.sourceOptions.batchId + 1)
@@ -99,7 +117,7 @@ abstract class StatePartitionReaderBase(
stateStoreColFamilySchema.keySchema,
stateStoreColFamilySchema.valueSchema,
stateStoreColFamilySchema.keyStateEncoderSpec.get,
- useMultipleValuesPerKey = false)
+ useMultipleValuesPerKey = useMultipleValuesPerKey)
}
provider
}
@@ -166,16 +184,22 @@ class StatePartitionReader(
stateVariableInfoOpt match {
case Some(stateVarInfo) =>
val stateVarType = stateVarInfo.stateVariableType
- val hasTTLEnabled = stateVarInfo.ttlEnabled
stateVarType match {
case StateVariableType.ValueState =>
- if (hasTTLEnabled) {
- SchemaUtil.unifyStateRowPairWithTTL((pair.key, pair.value), valueSchema,
- partition.partition)
- } else {
- SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition)
+ SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition)
+
+ case StateVariableType.ListState =>
+ val key = pair.key
+ val result = store.valuesIterator(key, stateVarName)
+ var unsafeRowArr: Seq[UnsafeRow] = Seq.empty
+ result.foreach { entry =>
+ unsafeRowArr = unsafeRowArr :+ entry.copy()
}
+ // convert the list of values to array type
+ val arrData = new GenericArrayData(unsafeRowArr.toArray)
+ SchemaUtil.unifyStateRowPairWithMultipleValues((pair.key, arrData),
+ partition.partition)
case _ =>
throw new IllegalStateException(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
index 9dd357530ec4..47bf9250000a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
@@ -19,10 +19,11 @@ package org.apache.spark.sql.execution.datasources.v2.state.utils
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
+import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.execution.datasources.v2.state.{StateDataSourceErrors, StateSourceOptions}
import org.apache.spark.sql.execution.streaming.{StateVariableType, TransformWithStateVariableInfo}
import org.apache.spark.sql.execution.streaming.state.StateStoreColFamilySchema
-import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructType}
+import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, LongType, StringType, StructType}
import org.apache.spark.util.ArrayImplicits._
object SchemaUtil {
@@ -70,15 +71,13 @@ object SchemaUtil {
row
}
- def unifyStateRowPairWithTTL(
- pair: (UnsafeRow, UnsafeRow),
- valueSchema: StructType,
+ def unifyStateRowPairWithMultipleValues(
+ pair: (UnsafeRow, GenericArrayData),
partition: Int): InternalRow = {
- val row = new GenericInternalRow(4)
+ val row = new GenericInternalRow(3)
row.update(0, pair._1)
- row.update(1, pair._2.get(0, valueSchema))
- row.update(2, pair._2.get(1, LongType))
- row.update(3, partition)
+ row.update(1, pair._2)
+ row.update(2, partition)
row
}
@@ -91,23 +90,22 @@ object SchemaUtil {
"change_type" -> classOf[StringType],
"key" -> classOf[StructType],
"value" -> classOf[StructType],
- "partition_id" -> classOf[IntegerType],
- "expiration_timestamp" -> classOf[LongType])
+ "single_value" -> classOf[StructType],
+ "list_value" -> classOf[ArrayType],
+ "partition_id" -> classOf[IntegerType])
val expectedFieldNames = if (sourceOptions.readChangeFeed) {
Seq("batch_id", "change_type", "key", "value", "partition_id")
} else if (transformWithStateVariableInfoOpt.isDefined) {
val stateVarInfo = transformWithStateVariableInfoOpt.get
- val hasTTLEnabled = stateVarInfo.ttlEnabled
val stateVarType = stateVarInfo.stateVariableType
stateVarType match {
case StateVariableType.ValueState =>
- if (hasTTLEnabled) {
- Seq("key", "value", "expiration_timestamp", "partition_id")
- } else {
- Seq("key", "value", "partition_id")
- }
+ Seq("key", "single_value", "partition_id")
+
+ case StateVariableType.ListState =>
+ Seq("key", "list_value", "partition_id")
case _ =>
throw StateDataSourceErrors
@@ -131,24 +129,19 @@ object SchemaUtil {
stateVarInfo: TransformWithStateVariableInfo,
stateStoreColFamilySchema: StateStoreColFamilySchema): StructType = {
val stateVarType = stateVarInfo.stateVariableType
- val hasTTLEnabled = stateVarInfo.ttlEnabled
stateVarType match {
case StateVariableType.ValueState =>
- if (hasTTLEnabled) {
- val ttlValueSchema = SchemaUtil.getSchemaAsDataType(
- stateStoreColFamilySchema.valueSchema, "value").asInstanceOf[StructType]
- new StructType()
- .add("key", stateStoreColFamilySchema.keySchema)
- .add("value", ttlValueSchema)
- .add("expiration_timestamp", LongType)
- .add("partition_id", IntegerType)
- } else {
- new StructType()
- .add("key", stateStoreColFamilySchema.keySchema)
- .add("value", stateStoreColFamilySchema.valueSchema)
- .add("partition_id", IntegerType)
- }
+ new StructType()
+ .add("key", stateStoreColFamilySchema.keySchema)
+ .add("single_value", stateStoreColFamilySchema.valueSchema)
+ .add("partition_id", IntegerType)
+
+ case StateVariableType.ListState =>
+ new StructType()
+ .add("key", stateStoreColFamilySchema.keySchema)
+ .add("list_value", ArrayType(stateStoreColFamilySchema.valueSchema))
+ .add("partition_id", IntegerType)
case _ =>
throw StateDataSourceErrors.internalError(s"Unsupported state variable type $stateVarType")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
index ccd4e005756a..1c06e4f97f2b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
@@ -21,8 +21,9 @@ import java.time.Duration
import org.apache.spark.sql.{Encoders, Row}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider, TestClass}
+import org.apache.spark.sql.functions.explode
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.streaming.{ExpiredTimerInfo, OutputMode, RunningCountStatefulProcessor, StatefulProcessor, StateStoreMetricsTest, TimeMode, TimerValues, TransformWithStateSuiteUtils, TTLConfig, ValueState}
+import org.apache.spark.sql.streaming.{ExpiredTimerInfo, ListState, OutputMode, RunningCountStatefulProcessor, StatefulProcessor, StateStoreMetricsTest, TimeMode, TimerValues, TransformWithStateSuiteUtils, TTLConfig, ValueState}
/** Stateful processor of single value state var with non-primitive type */
class StatefulProcessorWithSingleValueVar extends RunningCountStatefulProcessor {
@@ -73,6 +74,52 @@ class StatefulProcessorWithTTL
}
}
+/** Stateful processor tracking groups belonging to sessions with/without TTL */
+class SessionGroupsStatefulProcessor extends
+ StatefulProcessor[String, (String, String), String] {
+ @transient private var _groupsList: ListState[String] = _
+
+ override def init(
+ outputMode: OutputMode,
+ timeMode: TimeMode): Unit = {
+ _groupsList = getHandle.getListState("groupsList", Encoders.STRING)
+ }
+
+ override def handleInputRows(
+ key: String,
+ inputRows: Iterator[(String, String)],
+ timerValues: TimerValues,
+ expiredTimerInfo: ExpiredTimerInfo): Iterator[String] = {
+ inputRows.foreach { inputRow =>
+ _groupsList.appendValue(inputRow._2)
+ }
+ Iterator.empty
+ }
+}
+
+class SessionGroupsStatefulProcessorWithTTL extends
+ StatefulProcessor[String, (String, String), String] {
+ @transient private var _groupsListWithTTL: ListState[String] = _
+
+ override def init(
+ outputMode: OutputMode,
+ timeMode: TimeMode): Unit = {
+ _groupsListWithTTL = getHandle.getListState("groupsListWithTTL", Encoders.STRING,
+ TTLConfig(Duration.ofMillis(30000)))
+ }
+
+ override def handleInputRows(
+ key: String,
+ inputRows: Iterator[(String, String)],
+ timerValues: TimerValues,
+ expiredTimerInfo: ExpiredTimerInfo): Iterator[String] = {
+ inputRows.foreach { inputRow =>
+ _groupsListWithTTL.appendValue(inputRow._2)
+ }
+ Iterator.empty
+ }
+}
+
/**
* Test suite to verify integration of state data source reader with the transformWithState operator
*/
@@ -111,7 +158,7 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest
val resultDf = stateReaderDf.selectExpr(
"key.value AS groupingKey",
- "value.id AS valueId", "value.name AS valueName",
+ "single_value.id AS valueId", "single_value.name AS valueName",
"partition_id")
checkAnswer(resultDf,
@@ -174,7 +221,7 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest
.load()
val resultDf = stateReaderDf.selectExpr(
- "key.value", "value.value", "expiration_timestamp", "partition_id")
+ "key.value", "single_value.value", "single_value.ttlExpirationMs",
"partition_id")
var count = 0L
resultDf.collect().foreach { row =>
@@ -187,7 +234,7 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest
val answerDf = stateReaderDf.selectExpr(
"key.value AS groupingKey",
- "value.value AS valueId", "partition_id")
+ "single_value.value.value AS valueId", "partition_id")
checkAnswer(answerDf,
Seq(Row("a", 1L, 0), Row("b", 1L, 1)))
@@ -217,4 +264,110 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest
}
}
}
+
+ test("state data source integration - list state") {
+ withTempDir { tempDir =>
+ withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+ classOf[RocksDBStateStoreProvider].getName) {
+
+ val inputData = MemoryStream[(String, String)]
+ val result = inputData.toDS()
+ .groupByKey(x => x._1)
+ .transformWithState(new SessionGroupsStatefulProcessor(),
+ TimeMode.None(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath),
+ AddData(inputData, ("session1", "group2")),
+ AddData(inputData, ("session1", "group1")),
+ AddData(inputData, ("session2", "group1")),
+ CheckNewAnswer(),
+ AddData(inputData, ("session3", "group7")),
+ AddData(inputData, ("session1", "group4")),
+ CheckNewAnswer(),
+ StopStream
+ )
+
+ val stateReaderDf = spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+ .option(StateSourceOptions.STATE_VAR_NAME, "groupsList")
+ .load()
+
+ val listStateDf = stateReaderDf
+ .selectExpr(
+ "key.value AS groupingKey",
+ "list_value.value AS valueList",
+ "partition_id")
+ .select($"groupingKey",
+ explode($"valueList"))
+
+ checkAnswer(listStateDf,
+ Seq(Row("session1", "group1"), Row("session1", "group2"),
Row("session1", "group4"),
+ Row("session2", "group1"), Row("session3", "group7")))
+ }
+ }
+ }
+
+ test("state data source integration - list state and TTL") {
+ withTempDir { tempDir =>
+ withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+ classOf[RocksDBStateStoreProvider].getName,
+ SQLConf.SHUFFLE_PARTITIONS.key ->
+ TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
+ val inputData = MemoryStream[(String, String)]
+ val result = inputData.toDS()
+ .groupByKey(x => x._1)
+ .transformWithState(new SessionGroupsStatefulProcessorWithTTL(),
+ TimeMode.ProcessingTime(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath),
+ AddData(inputData, ("session1", "group2")),
+ AddData(inputData, ("session1", "group1")),
+ AddData(inputData, ("session2", "group1")),
+ AddData(inputData, ("session3", "group7")),
+ AddData(inputData, ("session1", "group4")),
+ Execute { _ =>
+ // wait for the batch to run since we are using processing time
+ Thread.sleep(5000)
+ },
+ StopStream
+ )
+
+ val stateReaderDf = spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+ .option(StateSourceOptions.STATE_VAR_NAME, "groupsListWithTTL")
+ .load()
+
+ val listStateDf = stateReaderDf
+ .selectExpr(
+ "key.value AS groupingKey",
+ "list_value AS valueList",
+ "partition_id")
+ .select($"groupingKey",
+ explode($"valueList").as("valueList"))
+
+ val resultDf = listStateDf.selectExpr("valueList.ttlExpirationMs")
+ var count = 0L
+ resultDf.collect().foreach { row =>
+ count = count + 1
+ assert(row.getLong(0) > 0)
+ }
+
+ // verify that 5 state rows are present
+ assert(count === 5)
+
+ val valuesDf = listStateDf.selectExpr("groupingKey",
+ "valueList.value.value AS groupId")
+
+ checkAnswer(valuesDf,
+ Seq(Row("session1", "group1"), Row("session1", "group2"),
Row("session1", "group4"),
+ Row("session2", "group1"), Row("session3", "group7")))
+ }
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]