This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new fe1cf3200223 [SPARK-49656][SS] Add support for state variables with
value state collection types and read change feed options
fe1cf3200223 is described below
commit fe1cf3200223c33ed4670bfa5924d5a4053c8ef9
Author: Anish Shrigondekar <[email protected]>
AuthorDate: Thu Sep 26 17:38:58 2024 +0900
[SPARK-49656][SS] Add support for state variables with value state
collection types and read change feed options
### What changes were proposed in this pull request?
Add support for state variables with value state collection types and read
change feed options
### Why are the changes needed?
Without this, we cannot support reading per key changes for state variables
used with stateful processors.
### Does this PR introduce _any_ user-facing change?
Yes
Users can now query value state variables with the following query:
```
val changeFeedDf = spark.read
.format("statestore")
.option(StateSourceOptions.PATH, <checkpoint_loc>)
.option(StateSourceOptions.STATE_VAR_NAME, <state_var_name>)
.option(StateSourceOptions.READ_CHANGE_FEED, true)
.option(StateSourceOptions.CHANGE_START_BATCH_ID, 0)
.load()
```
### How was this patch tested?
Added unit tests
```
[info] Run completed in 17 seconds, 318 milliseconds.
[info] Total number of tests run: 2
[info] Suites: completed 1, aborted 0
[info] Tests: succeeded 2, failed 0, canceled 0, ignored 0, pending 0
[info] All tests passed.
```
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #48148 from anishshri-db/task/SPARK-49656.
Authored-by: Anish Shrigondekar <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
---
.../datasources/v2/state/StateDataSource.scala | 10 +-
.../v2/state/StatePartitionReader.scala | 10 +-
.../state/HDFSBackedStateStoreProvider.scala | 10 +-
.../state/RocksDBStateStoreProvider.scala | 79 ++++++++++++---
.../sql/execution/streaming/state/StateStore.scala | 6 +-
.../streaming/state/StateStoreChangelog.scala | 11 ++-
.../StateDataSourceTransformWithStateSuite.scala | 107 +++++++++++++++++----
7 files changed, 190 insertions(+), 43 deletions(-)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
index 39bc4dd9fb9c..edddfbd6ccae 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
@@ -33,7 +33,7 @@ import
org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.{J
import
org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues.JoinSideValues
import
org.apache.spark.sql.execution.datasources.v2.state.metadata.{StateMetadataPartitionReader,
StateMetadataTableEntry}
import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
-import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog,
OffsetSeqMetadata, TimerStateUtils, TransformWithStateOperatorProperties,
TransformWithStateVariableInfo}
+import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog,
OffsetSeqMetadata, StateVariableType, TimerStateUtils,
TransformWithStateOperatorProperties, TransformWithStateVariableInfo}
import
org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.{DIR_NAME_COMMITS,
DIR_NAME_OFFSETS, DIR_NAME_STATE}
import
org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide,
RightSide}
import org.apache.spark.sql.execution.streaming.state.{KeyStateEncoderSpec,
NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec,
StateSchemaCompatibilityChecker, StateStore, StateStoreColFamilySchema,
StateStoreConf, StateStoreId, StateStoreProviderId}
@@ -170,13 +170,15 @@ class StateDataSource extends TableProvider with
DataSourceRegister with Logging
}
val stateVars = twsOperatorProperties.stateVariables
- if (stateVars.filter(stateVar => stateVar.stateName ==
stateVarName).size != 1) {
+ val stateVarInfo = stateVars.filter(stateVar => stateVar.stateName ==
stateVarName)
+ if (stateVarInfo.size != 1) {
throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME,
s"State variable $stateVarName is not defined for the
transformWithState operator.")
}
- // TODO: Support change feed and transformWithState together
- if (sourceOptions.readChangeFeed) {
+ // TODO: add support for list and map type
+ if (sourceOptions.readChangeFeed &&
+ stateVarInfo.head.stateVariableType != StateVariableType.ValueState) {
throw
StateDataSourceErrors.conflictOptions(Seq(StateSourceOptions.READ_CHANGE_FEED,
StateSourceOptions.STATE_VAR_NAME))
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
index d77d97f0057f..b925aee5b627 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
@@ -223,10 +223,18 @@ class StateStoreChangeDataPartitionReader(
throw StateStoreErrors.stateStoreProviderDoesNotSupportFineGrainedReplay(
provider.getClass.toString)
}
+
+ val colFamilyNameOpt = if (stateVariableInfoOpt.isDefined) {
+ Some(stateVariableInfoOpt.get.stateName)
+ } else {
+ None
+ }
+
provider.asInstanceOf[SupportsFineGrainedReplay]
.getStateStoreChangeDataReader(
partition.sourceOptions.readChangeFeedOptions.get.changeStartBatchId +
1,
- partition.sourceOptions.readChangeFeedOptions.get.changeEndBatchId + 1)
+ partition.sourceOptions.readChangeFeedOptions.get.changeEndBatchId + 1,
+ colFamilyNameOpt)
}
override lazy val iter: Iterator[InternalRow] = {
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index d9f4443b7961..884b8aa3853c 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -991,8 +991,16 @@ private[sql] class HDFSBackedStateStoreProvider extends
StateStoreProvider with
result
}
- override def getStateStoreChangeDataReader(startVersion: Long, endVersion:
Long):
+ override def getStateStoreChangeDataReader(
+ startVersion: Long,
+ endVersion: Long,
+ colFamilyNameOpt: Option[String] = None):
StateStoreChangeDataReader = {
+ // Multiple column families are not supported with
HDFSBackedStateStoreProvider
+ if (colFamilyNameOpt.isDefined) {
+ throw StateStoreErrors.multipleColumnFamiliesNotSupported(providerName)
+ }
+
new HDFSBackedStateStoreChangeDataReader(fm, baseDir, startVersion,
endVersion,
CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec),
keySchema, valueSchema)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index 85f80ce9eb1a..6ab634668bc2 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -498,7 +498,10 @@ private[sql] class RocksDBStateStoreProvider
}
}
- override def getStateStoreChangeDataReader(startVersion: Long, endVersion:
Long):
+ override def getStateStoreChangeDataReader(
+ startVersion: Long,
+ endVersion: Long,
+ colFamilyNameOpt: Option[String] = None):
StateStoreChangeDataReader = {
val statePath = stateStoreId.storeCheckpointLocation()
val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
@@ -508,7 +511,8 @@ private[sql] class RocksDBStateStoreProvider
startVersion,
endVersion,
CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec),
- keyValueEncoderMap)
+ keyValueEncoderMap,
+ colFamilyNameOpt)
}
/**
@@ -676,27 +680,70 @@ class RocksDBStateStoreChangeDataReader(
endVersion: Long,
compressionCodec: CompressionCodec,
keyValueEncoderMap:
- ConcurrentHashMap[String, (RocksDBKeyStateEncoder,
RocksDBValueStateEncoder)])
+ ConcurrentHashMap[String, (RocksDBKeyStateEncoder,
RocksDBValueStateEncoder)],
+ colFamilyNameOpt: Option[String] = None)
extends StateStoreChangeDataReader(
- fm, stateLocation, startVersion, endVersion, compressionCodec) {
+ fm, stateLocation, startVersion, endVersion, compressionCodec,
colFamilyNameOpt) {
override protected var changelogSuffix: String = "changelog"
+ private def getColFamilyIdBytes: Option[Array[Byte]] = {
+ if (colFamilyNameOpt.isDefined) {
+ val colFamilyName = colFamilyNameOpt.get
+ if (!keyValueEncoderMap.containsKey(colFamilyName)) {
+ throw new IllegalStateException(
+ s"Column family $colFamilyName not found in the key value encoder
map")
+ }
+ Some(keyValueEncoderMap.get(colFamilyName)._1.getColumnFamilyIdBytes())
+ } else {
+ None
+ }
+ }
+
+ private val colFamilyIdBytesOpt: Option[Array[Byte]] = getColFamilyIdBytes
+
override def getNext(): (RecordType.Value, UnsafeRow, UnsafeRow, Long) = {
- val reader = currentChangelogReader()
- if (reader == null) {
- return null
+ var currRecord: (RecordType.Value, Array[Byte], Array[Byte]) = null
+ val currEncoder: (RocksDBKeyStateEncoder, RocksDBValueStateEncoder) =
+ keyValueEncoderMap.get(colFamilyNameOpt
+ .getOrElse(StateStore.DEFAULT_COL_FAMILY_NAME))
+
+ if (colFamilyIdBytesOpt.isDefined) {
+ // If we are reading records for a particular column family, the
corresponding vcf id
+ // will be encoded in the key byte array. We need to extract that and
compare for the
+ // expected column family id. If it matches, we return the record. If
not, we move to
+ // the next record. Note that this has to be handled across multiple
changelog files and we
+ // rely on the currentChangelogReader to move to the next changelog file
when needed.
+ while (currRecord == null) {
+ val reader = currentChangelogReader()
+ if (reader == null) {
+ return null
+ }
+
+ val nextRecord = reader.next()
+ val colFamilyIdBytes: Array[Byte] = colFamilyIdBytesOpt.get
+ val endIndex = colFamilyIdBytes.size
+ // Function checks for byte arrays being equal
+ // from index 0 to endIndex - 1 (both inclusive)
+ if (java.util.Arrays.equals(nextRecord._2, 0, endIndex,
+ colFamilyIdBytes, 0, endIndex)) {
+ currRecord = nextRecord
+ }
+ }
+ } else {
+ val reader = currentChangelogReader()
+ if (reader == null) {
+ return null
+ }
+ currRecord = reader.next()
}
- val (recordType, keyArray, valueArray) = reader.next()
- // Todo: does not support multiple virtual column families
- val (rocksDBKeyStateEncoder, rocksDBValueStateEncoder) =
- keyValueEncoderMap.get(StateStore.DEFAULT_COL_FAMILY_NAME)
- val keyRow = rocksDBKeyStateEncoder.decodeKey(keyArray)
- if (valueArray == null) {
- (recordType, keyRow, null, currentChangelogVersion - 1)
+
+ val keyRow = currEncoder._1.decodeKey(currRecord._2)
+ if (currRecord._3 == null) {
+ (currRecord._1, keyRow, null, currentChangelogVersion - 1)
} else {
- val valueRow = rocksDBValueStateEncoder.decodeValue(valueArray)
- (recordType, keyRow, valueRow, currentChangelogVersion - 1)
+ val valueRow = currEncoder._2.decodeValue(currRecord._3)
+ (currRecord._1, keyRow, valueRow, currentChangelogVersion - 1)
}
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index d55a973a14e1..6e616cc71a80 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -519,10 +519,14 @@ trait SupportsFineGrainedReplay {
*
* @param startVersion starting changelog version
* @param endVersion ending changelog version
+ * @param colFamilyNameOpt optional column family name to read from
* @return iterator that gives tuple(recordType: [[RecordType.Value]],
nested key: [[UnsafeRow]],
* nested value: [[UnsafeRow]], batchId: [[Long]])
*/
- def getStateStoreChangeDataReader(startVersion: Long, endVersion: Long):
+ def getStateStoreChangeDataReader(
+ startVersion: Long,
+ endVersion: Long,
+ colFamilyNameOpt: Option[String] = None):
NextIterator[(RecordType.Value, UnsafeRow, UnsafeRow, Long)]
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala
index 651d72da1609..e89550da37e0 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala
@@ -397,13 +397,15 @@ class StateStoreChangelogReaderV2(
* @param startVersion start version of the changelog file to read
* @param endVersion end version of the changelog file to read
* @param compressionCodec de-compression method used for reading changelog
file
+ * @param colFamilyNameOpt optional column family name to read from
*/
abstract class StateStoreChangeDataReader(
fm: CheckpointFileManager,
stateLocation: Path,
startVersion: Long,
endVersion: Long,
- compressionCodec: CompressionCodec)
+ compressionCodec: CompressionCodec,
+ colFamilyNameOpt: Option[String] = None)
extends NextIterator[(RecordType.Value, UnsafeRow, UnsafeRow, Long)] with
Logging {
assert(startVersion >= 1)
@@ -451,9 +453,12 @@ abstract class StateStoreChangeDataReader(
finished = true
return null
}
- // Todo: Does not support StateStoreChangelogReaderV2
- changelogReader =
+
+ changelogReader = if (colFamilyNameOpt.isDefined) {
+ new StateStoreChangelogReaderV2(fm, fileIterator.next(),
compressionCodec)
+ } else {
new StateStoreChangelogReaderV1(fm, fileIterator.next(),
compressionCodec)
+ }
}
changelogReader
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
index bd047d1132fb..84c6eb54681a 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
@@ -186,18 +186,49 @@ class StateDataSourceTransformWithStateSuite extends
StateStoreMetricsTest
}
assert(ex1.isInstanceOf[StateDataSourceInvalidOptionValue])
assert(ex1.getMessage.contains("Registered timers are not available"))
+ }
+ }
+ }
- // TODO: this should be removed when readChangeFeed is supported for
value state
- val ex2 = intercept[Exception] {
- spark.read
+ testWithChangelogCheckpointingEnabled("state data source cdf integration - "
+
+ "value state with single variable") {
+ withTempDir { tempDir =>
+ withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+ classOf[RocksDBStateStoreProvider].getName,
+ SQLConf.SHUFFLE_PARTITIONS.key ->
+ TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
+ val inputData = MemoryStream[String]
+ val result = inputData.toDS()
+ .groupByKey(x => x)
+ .transformWithState(new StatefulProcessorWithSingleValueVar(),
+ TimeMode.None(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath),
+ AddData(inputData, "a"),
+ CheckNewAnswer(("a", "1")),
+ AddData(inputData, "b"),
+ CheckNewAnswer(("b", "1")),
+ StopStream
+ )
+
+ val changeFeedDf = spark.read
.format("statestore")
.option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
.option(StateSourceOptions.STATE_VAR_NAME, "valueState")
- .option(StateSourceOptions.READ_CHANGE_FEED, "true")
+ .option(StateSourceOptions.READ_CHANGE_FEED, true)
.option(StateSourceOptions.CHANGE_START_BATCH_ID, 0)
.load()
- }
- assert(ex2.isInstanceOf[StateDataSourceConflictOptions])
+
+ val opDf = changeFeedDf.selectExpr(
+ "change_type",
+ "key.value AS groupingKey",
+ "value.id AS valueId", "value.name AS valueName",
+ "partition_id")
+
+ checkAnswer(opDf,
+ Seq(Row("update", "a", 1L, "dummyKey", 0), Row("update", "b", 1L,
"dummyKey", 1)))
}
}
}
@@ -260,19 +291,61 @@ class StateDataSourceTransformWithStateSuite extends
StateStoreMetricsTest
}
assert(ex.isInstanceOf[StateDataSourceInvalidOptionValue])
assert(ex.getMessage.contains("State variable non-exist is not
defined"))
+ }
+ }
+ }
- // TODO: this should be removed when readChangeFeed is supported for
TTL based state
- // variables
- val ex1 = intercept[Exception] {
- spark.read
- .format("statestore")
- .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
- .option(StateSourceOptions.STATE_VAR_NAME, "countState")
- .option(StateSourceOptions.READ_CHANGE_FEED, "true")
- .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0)
- .load()
+ testWithChangelogCheckpointingEnabled("state data source cdf integration - "
+
+ "value state with single variable and TTL") {
+ withTempDir { tempDir =>
+ withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+ classOf[RocksDBStateStoreProvider].getName,
+ SQLConf.SHUFFLE_PARTITIONS.key ->
+ TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
+ val inputData = MemoryStream[String]
+ val result = inputData.toDS()
+ .groupByKey(x => x)
+ .transformWithState(new StatefulProcessorWithTTL(),
+ TimeMode.ProcessingTime(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath),
+ AddData(inputData, "a"),
+ AddData(inputData, "b"),
+ Execute { _ =>
+ // wait for the batch to run since we are using processing time
+ Thread.sleep(5000)
+ },
+ StopStream
+ )
+
+ val stateReaderDf = spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+ .option(StateSourceOptions.STATE_VAR_NAME, "countState")
+ .option(StateSourceOptions.READ_CHANGE_FEED, true)
+ .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0)
+ .load()
+
+ val resultDf = stateReaderDf.selectExpr(
+ "key.value", "value.value", "value.ttlExpirationMs", "partition_id")
+
+ var count = 0L
+ resultDf.collect().foreach { row =>
+ count = count + 1
+ assert(row.getLong(2) > 0)
}
- assert(ex1.isInstanceOf[StateDataSourceConflictOptions])
+
+ // verify that 2 state rows are present
+ assert(count === 2)
+
+ val answerDf = stateReaderDf.selectExpr(
+ "change_type",
+ "key.value AS groupingKey",
+ "value.value.value AS valueId", "partition_id")
+ checkAnswer(answerDf,
+ Seq(Row("update", "a", 1L, 0), Row("update", "b", 1L, 1)))
}
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]