micheal-o commented on code in PR #53316:
URL: https://github.com/apache/spark/pull/53316#discussion_r2621514601
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala:
##########
@@ -268,13 +291,23 @@ class StateDataSource extends TableProvider with
DataSourceRegister with Logging
if (sourceOptions.readRegisteredTimers) {
stateVarName = TimerStateUtils.getTimerStateVarNames(timeMode)._1
}
-
- val stateVarInfoList = operatorProperties.stateVariables
- .filter(stateVar => stateVar.stateName == stateVarName)
- require(stateVarInfoList.size == 1, s"Failed to find unique state
variable info " +
- s"for state variable $stateVarName in operator
${sourceOptions.operatorId}")
- val stateVarInfo = stateVarInfoList.head
- transformWithStateVariableInfoOpt = Some(stateVarInfo)
+ if (sourceOptions.internalOnlyReadAllColumnFamilies) {
+ stateVariableInfos = operatorProperties.stateVariables
+ } else {
+ var stateVarInfoList = operatorProperties.stateVariables
+ .filter(stateVar => stateVar.stateName == stateVarName)
+ if (stateVarInfoList.isEmpty &&
Review Comment:
nit: add a comment explaining why
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala:
##########
@@ -258,32 +266,116 @@ class StatePartitionAllColumnFamiliesReader(
partition: StateStoreInputPartition,
schema: StructType,
keyStateEncoderSpec: KeyStateEncoderSpec,
- stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema])
+ defaultStateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
+ stateSchemaProviderOpt: Option[StateSchemaProvider],
+ allColumnFamiliesReaderInfo: AllColumnFamiliesReaderInfo)
extends StatePartitionReaderBase(
storeConf,
hadoopConf, partition, schema,
- keyStateEncoderSpec, None, stateStoreColFamilySchemaOpt, None, None) {
+ keyStateEncoderSpec, None,
+ defaultStateStoreColFamilySchemaOpt,
+ stateSchemaProviderOpt, None) {
- private lazy val store: ReadStateStore = {
+ private val stateStoreColFamilySchemas =
allColumnFamiliesReaderInfo.colFamilySchemas
+ private val stateVariableInfos =
allColumnFamiliesReaderInfo.stateVariableInfos
+
+ private def isListType(colFamilyName: String): Boolean = {
+ SchemaUtil.checkVariableType(
+ stateVariableInfos.find(info => info.stateName == colFamilyName),
+ StateVariableType.ListState)
+ }
+
+ override protected lazy val provider: StateStoreProvider = {
+ val stateStoreId =
StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
+ partition.sourceOptions.operatorId, partition.partition,
partition.sourceOptions.storeName)
+ val stateStoreProviderId = StateStoreProviderId(stateStoreId,
partition.queryId)
+ val useColumnFamilies = stateStoreColFamilySchemas.length > 1
+ StateStoreProvider.createAndInit(
+ stateStoreProviderId, keySchema, valueSchema, keyStateEncoderSpec,
+ useColumnFamilies, storeConf, hadoopConf.value,
+ useMultipleValuesPerKey = false, stateSchemaProviderOpt)
+ }
+
+
+ private def checkAllColFamiliesExist(
+ colFamilyNames: List[String], stateStore: StateStore
+ ): Unit = {
+ // Filter out DEFAULT column family from validation for two reasons:
+ // 1. Some operators (e.g., stream-stream join v3) don't include DEFAULT
in their schema
+ // because the underlying RocksDB creates "default" column family
automatically
+ // 2. The default column family schema is handled separately via
+ // defaultStateStoreColFamilySchemaOpt, so no need to verify it here
+ val actualCFs = colFamilyNames.toSet.filter(_ !=
StateStore.DEFAULT_COL_FAMILY_NAME)
+ val expectedCFs = stateStore.allColumnFamilyNames
+ .filter(_ != StateStore.DEFAULT_COL_FAMILY_NAME)
+
+ // Validation: All column families found in the checkpoint must be
declared in the schema.
+ // It's acceptable if some schema CFs are not in expectedCFs - this just
means those
+ // column families have no data yet in the checkpoint
+ // (they'll be created during registration).
+ // However, if the checkpoint contains CFs not in the schema, it indicates
a mismatch.
+ require(expectedCFs.subsetOf(actualCFs),
+ s"Checkpoint contains unexpected column families. " +
+ s"Column families in checkpoint but not in schema:
${expectedCFs.diff(actualCFs)}")
+ }
+
+ // Use a single store instance for both registering column families and
iteration.
+ // We cannot abort and then get a read store because abort() invalidates the
loaded version,
+ // causing getReadStore() to reload from checkpoint and clear the column
family registrations.
+ private lazy val store: StateStore = {
assert(getStartStoreUniqueId == getEndStoreUniqueId,
"Start and end store unique IDs must be the same when reading all column
families")
- provider.getReadStore(
+ val stateStore = provider.getStore(
partition.sourceOptions.batchId + 1,
getStartStoreUniqueId
)
+
+ // Register all column families from the schema
+ if (stateStoreColFamilySchemas.length > 1) {
+
checkAllColFamiliesExist(stateStoreColFamilySchemas.map(_.colFamilyName),
stateStore)
+ stateStoreColFamilySchemas.foreach { cfSchema =>
+ cfSchema.colFamilyName match {
+ case StateStore.DEFAULT_COL_FAMILY_NAME => // createAndInit has
registered default
+ case _ =>
+ val isInternal = cfSchema.colFamilyName.startsWith("$")
Review Comment:
nit: Let's add a util func for this in `StateStoreColumnFamilySchemaUtils.scala`
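
For illustration, a minimal sketch of what such a util could look like; the name `isInternalColumnFamily` is assumed here, not taken from the PR:

```scala
// To be added inside the existing StateStoreColumnFamilySchemaUtils object.
// Internal column families (e.g. "$rowCounter_itemsList", "$ttl_listState") use a "$" prefix.
def isInternalColumnFamily(colFamilyName: String): Boolean =
  colFamilyName.startsWith("$")

// The reader call site above would then read:
// val isInternal =
//   StateStoreColumnFamilySchemaUtils.isInternalColumnFamily(cfSchema.colFamilyName)
```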
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala:
##########
@@ -177,7 +193,8 @@ class StateDataSource extends TableProvider with
DataSourceRegister with Logging
val stateVars = twsOperatorProperties.stateVariables
val stateVarInfo = stateVars.filter(stateVar => stateVar.stateName ==
stateVarName)
- if (stateVarInfo.size != 1) {
+ if (stateVarInfo.size != 1 &&
+
!StateStoreColumnFamilySchemaUtils.isInternalColFamilyTestOnly(stateVarName)) {
Review Comment:
nit: please add a comment explaining why
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/StateStoreColumnFamilySchemaUtils.scala:
##########
@@ -99,6 +99,17 @@ object StateStoreColumnFamilySchemaUtils {
def getStateNameFromCountIndexCFName(colFamilyName: String): String =
getStateName(COUNT_INDEX_PREFIX, colFamilyName)
+ /**
+ * Returns true if the column family is internal (starts with "$") and we
are in testing mode.
+ * This is used to allow internal column families to be read during tests.
+ *
+ * @param colFamilyName The name of the column family to check
+ * @return true if this is an internal column family and Utils.isTesting is
true
+ */
+ def isInternalColFamilyTestOnly(colFamilyName: String): Boolean = {
Review Comment:
nit: `isTestingInternalColFamily`
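
A sketch of the renamed helper; the body is inferred from the scaladoc above (internal "$" prefix check gated on `Utils.isTesting`) and may differ from the actual PR code:

```scala
// Renamed from isInternalColFamilyTestOnly; behavior unchanged.
def isTestingInternalColFamily(colFamilyName: String): Boolean = {
  Utils.isTesting && colFamilyName.startsWith("$")
}
```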
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala:
##########
@@ -258,32 +266,116 @@ class StatePartitionAllColumnFamiliesReader(
partition: StateStoreInputPartition,
schema: StructType,
keyStateEncoderSpec: KeyStateEncoderSpec,
- stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema])
+ defaultStateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
+ stateSchemaProviderOpt: Option[StateSchemaProvider],
+ allColumnFamiliesReaderInfo: AllColumnFamiliesReaderInfo)
extends StatePartitionReaderBase(
storeConf,
hadoopConf, partition, schema,
- keyStateEncoderSpec, None, stateStoreColFamilySchemaOpt, None, None) {
+ keyStateEncoderSpec, None,
+ defaultStateStoreColFamilySchemaOpt,
+ stateSchemaProviderOpt, None) {
- private lazy val store: ReadStateStore = {
+ private val stateStoreColFamilySchemas =
allColumnFamiliesReaderInfo.colFamilySchemas
Review Comment:
We should make this a set to make sure there are no duplicates, so that we don't process the same CF twice.
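
One possible way to do this, keyed on the column family name (a sketch; the distinct-by-name step is an assumption layered on top of the PR's `AllColumnFamiliesReaderInfo` fields):

```scala
// De-duplicate column family schemas by name before any registration or iteration,
// so the same CF is never processed twice.
private val stateStoreColFamilySchemas: Seq[StateStoreColFamilySchema] =
  allColumnFamiliesReaderInfo.colFamilySchemas
    .groupBy(_.colFamilyName)
    .values
    .map(_.head)
    .toSeq
```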
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala:
##########
@@ -258,32 +266,116 @@ class StatePartitionAllColumnFamiliesReader(
partition: StateStoreInputPartition,
schema: StructType,
keyStateEncoderSpec: KeyStateEncoderSpec,
- stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema])
+ defaultStateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
+ stateSchemaProviderOpt: Option[StateSchemaProvider],
+ allColumnFamiliesReaderInfo: AllColumnFamiliesReaderInfo)
extends StatePartitionReaderBase(
storeConf,
hadoopConf, partition, schema,
- keyStateEncoderSpec, None, stateStoreColFamilySchemaOpt, None, None) {
+ keyStateEncoderSpec, None,
+ defaultStateStoreColFamilySchemaOpt,
+ stateSchemaProviderOpt, None) {
- private lazy val store: ReadStateStore = {
+ private val stateStoreColFamilySchemas =
allColumnFamiliesReaderInfo.colFamilySchemas
+ private val stateVariableInfos =
allColumnFamiliesReaderInfo.stateVariableInfos
+
+ private def isListType(colFamilyName: String): Boolean = {
+ SchemaUtil.checkVariableType(
+ stateVariableInfos.find(info => info.stateName == colFamilyName),
+ StateVariableType.ListState)
+ }
+
+ override protected lazy val provider: StateStoreProvider = {
+ val stateStoreId =
StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
+ partition.sourceOptions.operatorId, partition.partition,
partition.sourceOptions.storeName)
+ val stateStoreProviderId = StateStoreProviderId(stateStoreId,
partition.queryId)
+ val useColumnFamilies = stateStoreColFamilySchemas.length > 1
+ StateStoreProvider.createAndInit(
+ stateStoreProviderId, keySchema, valueSchema, keyStateEncoderSpec,
+ useColumnFamilies, storeConf, hadoopConf.value,
+ useMultipleValuesPerKey = false, stateSchemaProviderOpt)
+ }
+
+
+ private def checkAllColFamiliesExist(
+ colFamilyNames: List[String], stateStore: StateStore
+ ): Unit = {
+ // Filter out DEFAULT column family from validation for two reasons:
+ // 1. Some operators (e.g., stream-stream join v3) don't include DEFAULT
in their schema
+ // because the underlying RocksDB creates "default" column family
automatically
+ // 2. The default column family schema is handled separately via
+ // defaultStateStoreColFamilySchemaOpt, so no need to verify it here
+ val actualCFs = colFamilyNames.toSet.filter(_ !=
StateStore.DEFAULT_COL_FAMILY_NAME)
+ val expectedCFs = stateStore.allColumnFamilyNames
+ .filter(_ != StateStore.DEFAULT_COL_FAMILY_NAME)
+
+ // Validation: All column families found in the checkpoint must be
declared in the schema.
+ // It's acceptable if some schema CFs are not in expectedCFs - this just
means those
+ // column families have no data yet in the checkpoint
+ // (they'll be created during registration).
+ // However, if the checkpoint contains CFs not in the schema, it indicates
a mismatch.
+ require(expectedCFs.subsetOf(actualCFs),
+ s"Checkpoint contains unexpected column families. " +
Review Comment:
Fix the message; it should be: Some column families are present in the state store but missing in the metadata.
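
A sketch of the `require` with the suggested wording, also folding in the related wording suggestion on the detail line made further down in this review (the condition itself is unchanged from the PR):

```scala
require(expectedCFs.subsetOf(actualCFs),
  "Some column families are present in the state store but missing in the metadata. " +
    s"Column families in state store but not in metadata: ${expectedCFs.diff(actualCFs)}")
```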
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala:
##########
@@ -310,6 +310,13 @@ trait StateStore extends ReadStateStore {
* Whether all updates have been committed
*/
def hasCommitted: Boolean
+
+ /**
+ * Returns all column family names in this state store.
+ *
+ * @return Set of all column family names
+ */
+ def allColumnFamilyNames: collection.Set[String]
Review Comment:
This should be added under `ReadStateStore` instead, so that we can call it for both ReadStateStore and StateStore, since it is a read-only operation.
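
A sketch of the suggested placement, with the trait's other members elided; using an immutable `Set[String]` here also anticipates the later comment about `collection.Set`:

```scala
// StateStore extends ReadStateStore, so declaring the accessor here makes it
// available on both read-only and writable stores.
trait ReadStateStore {
  /** Returns all column family names in this state store. */
  def allColumnFamilyNames: Set[String]
}
```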
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala:
##########
@@ -420,8 +534,9 @@ class StatePartitionAllColumnFamiliesReaderSuite extends
StateDataSourceTestBase
}
}
- def testStreamStreamJoin(stateVersion: Int): Unit = {
- withSQLConf(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION.key ->
stateVersion.toString) {
+ def testStreamStreamJoinV2(stateVersion: Int): Unit = {
Review Comment:
nit: `testStreamStreamJoinV1AndV2`. Also assert that the version passed in is <= 2.
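
For example, a sketch of the renamed helper with the version guard (the existing test body is elided, not changed):

```scala
def testStreamStreamJoinV1AndV2(stateVersion: Int): Unit = {
  assert(stateVersion <= 2,
    s"this helper only covers stream-stream join state format v1/v2, got $stateVersion")
  withSQLConf(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION.key -> stateVersion.toString) {
    // ... existing test body from testStreamStreamJoin ...
  }
}
```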
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala:
##########
@@ -533,4 +643,399 @@ class StatePartitionAllColumnFamiliesReaderSuite extends
StateDataSourceTestBase
}
}
}
+
+ test("SPARK-54419: transformWithState with multiple column families") {
+ withTempDir { tempDir =>
+ val inputData = MemoryStream[String]
+ val result = inputData.toDS()
+ .groupByKey(x => x)
+ .transformWithState(new MultiStateVarProcessor(),
+ TimeMode.None(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath),
+ AddData(inputData, "a", "b", "a"),
+ CheckNewAnswer(("a", "2"), ("b", "1")),
+ AddData(inputData, "b", "c"),
+ CheckNewAnswer(("b", "2"), ("c", "1")),
+ StopStream
+ )
+
+ // Read all column families using internalOnlyReadAllColumnFamilies
+ val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+ validateBytesReadDfSchema(bytesDf)
+ val allBytesData = bytesDf.collect()
+
+ val columnFamilies = allBytesData.map(_.getString(3)).distinct.sorted
+
+ // Verify countState column family exists
+ assert(columnFamilies.toSet ==
+ Set("countState", "itemsList", "$rowCounter_itemsList", "itemsMap"))
+
+ // Define schemas for each column family based on provided schema info
+ val groupByKeySchema = StructType(Array(
+ StructField("value", StringType, nullable = true)
+ ))
+ val countStateValueSchema = StructType(Array(
+ StructField("value", LongType, nullable = false)
+ ))
+ val itemsListValueSchema = StructType(Array(
+ StructField("value", StringType, nullable = true)
+ ))
+ val rowCounterValueSchema = StructType(Array(
+ StructField("count", LongType, nullable = true)
+ ))
+ val itemsMapKeySchema = StructType(Array(
+ StructField("key", groupByKeySchema),
+ StructField("user_map_key", groupByKeySchema, nullable = true)
+ ))
+ val itemsMapValueSchema = StructType(Array(
+ StructField("user_map_value", IntegerType, nullable = true)
+ ))
+
+ // Validate countState
+ readAndValidateStateVar(
+ tempDir.getAbsolutePath, allBytesData,
+ stateVarName = "countState", groupByKeySchema, countStateValueSchema)
+
+ // Validate itemsList
+ readAndValidateStateVar(
+ tempDir.getAbsolutePath, allBytesData,
+ stateVarName = "itemsList", groupByKeySchema, itemsListValueSchema,
+ extraOptions = Map(StateSourceOptions.FLATTEN_COLLECTION_TYPES ->
"true"),
+ selectExprs = Seq("partition_id", "key", "list_element"))
+
+ // Validate $rowCounter_itemsList - intentionally reuses countState's
data
+ val countStateNormalDf = getNormalReadDf(tempDir.getAbsolutePath,
Option("countState"))
+ compareNormalAndBytesData(
+ countStateNormalDf.collect(),
+ allBytesData,
+ "$rowCounter_itemsList",
+ groupByKeySchema,
+ rowCounterValueSchema)
+
+ // Validate itemsMap
+ readAndValidateStateVar(
+ tempDir.getAbsolutePath, allBytesData,
+ stateVarName = "itemsMap", itemsMapKeySchema, itemsMapValueSchema,
+ selectExprs = Seq("partition_id", "STRUCT(key, user_map_key) AS KEY",
+ "user_map_value AS value"))
+ }
+ }
+
+ test("SPARK-54419: read all column families with event time timers") {
+ withTempDir { tempDir =>
+ val inputData = MemoryStream[(String, Long)]
+ val result = inputData.toDS()
+ .select(col("_1").as("key"),
timestamp_seconds(col("_2")).as("eventTime"))
+ .withWatermark("eventTime", "10 seconds")
+ .as[(String, Timestamp)]
+ .groupByKey(_._1)
+ .transformWithState(
+ new EventTimeTimerProcessor(),
+ TimeMode.EventTime(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath),
+ AddData(inputData, ("a", 1L), ("b", 2L), ("c", 3L)),
+ CheckLastBatch(("a", "1"), ("b", "1"), ("c", "1")),
+ StopStream
+ )
+
+ validateTimerColumnFamilies(tempDir.getAbsolutePath, "event")
+ }
+ }
+
+ test("SPARK-54419: read all column families with processing time timers") {
+ withTempDir { tempDir =>
+ val clock = new StreamManualClock
+ val inputData = MemoryStream[String]
+ val result = inputData.toDS()
+ .groupByKey(x => x)
+ .transformWithState(new
RunningCountStatefulProcessorWithProcTimeTimer(),
+ TimeMode.ProcessingTime(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath,
+ trigger = Trigger.ProcessingTime("1 second"),
+ triggerClock = clock),
+ AddData(inputData, "a"),
+ AdvanceManualClock(1 * 1000),
+ CheckNewAnswer(("a", "1")),
+ StopStream
+ )
+
+ validateTimerColumnFamilies(tempDir.getAbsolutePath, "proc")
+ }
+ }
+
+ test("SPARK-54419: transformWithState with list state and TTL") {
+ withTempDir { tempDir =>
+ val clock = new StreamManualClock
+ val inputData = MemoryStream[String]
+ val result = inputData.toDS()
+ .groupByKey(x => x)
+ .transformWithState(new ListStateTTLProcessor(),
+ TimeMode.ProcessingTime(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath,
+ trigger = Trigger.ProcessingTime("1 second"),
+ triggerClock = clock),
+ AddData(inputData, "a", "b", "a"),
+ AdvanceManualClock(1 * 1000),
+ CheckNewAnswer(("a", "2"), ("b", "1")),
+ StopStream
+ )
+
+ val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+ validateBytesReadDfSchema(bytesDf)
+
+ val allBytesData = bytesDf.collect()
+ val columnFamilies = allBytesData.map(_.getString(3)).distinct.sorted
+
+ assert(columnFamilies.toSet ==
+ Set("listState", "$ttl_listState", "$min_listState",
"$count_listState"))
+
+ // Define schemas for list state with TTL column families
+ val groupByKeySchema = StructType(Array(
+ StructField("value", StringType, nullable = true)
+ ))
+ val listStateValueSchema = StructType(Array(
+ StructField("value", StructType(Array(
+ StructField("value", StringType, nullable = true)
+ )), nullable = false),
+ StructField("ttlExpirationMs", LongType, nullable = false)
+ ))
+
+ val listStateNormalDf = spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+ .option(StateSourceOptions.STATE_VAR_NAME, "listState")
+ .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, "true")
+ .load()
+ .selectExpr("partition_id", "key", "list_element")
+
+ compareNormalAndBytesData(
+ listStateNormalDf.collect(),
+ allBytesData,
+ "listState",
+ groupByKeySchema,
+ listStateValueSchema)
+ val dummyValueSchema = StructType(Array(StructField("__dummy__",
NullType)))
+ val ttlIndexKeySchema = StructType(Array(
+ StructField("expirationMs", LongType, nullable = false),
+ StructField("elementKey", groupByKeySchema)
+ ))
+ val minExpiryValueSchema = StructType(Array(
+ StructField("minExpiry", LongType)
+ ))
+ val countValueSchema = StructType(Array(
+ StructField("count", LongType)
+ ))
+ val columnFamilyAndKeyValueSchema = Seq(
+ ("$ttl_listState", ttlIndexKeySchema, dummyValueSchema),
+ ("$min_listState", groupByKeySchema, minExpiryValueSchema),
+ ("$count_listState", groupByKeySchema, countValueSchema)
+ )
+ columnFamilyAndKeyValueSchema.foreach(pair => {
+ val normalDf = spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+ .option(StateSourceOptions.STATE_VAR_NAME, pair._1)
+ .load()
+ .selectExpr("partition_id", "key", "value")
+
+ compareNormalAndBytesData(
+ normalDf.collect(),
+ allBytesData,
+ pair._1,
+ pair._2,
+ pair._3)
+ }
+ )
+ }
+ }
+
+ def testStreamStreamJoinV3(): Unit = {
+ withSQLConf(
+ SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION.key -> "3"
+ ) {
+ withTempDir { tempDir =>
+ val inputData = MemoryStream[(Int, Long)]
+ val query = getStreamStreamJoinQuery(inputData)
+ testStream(query)(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath),
+ AddData(inputData, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)),
+ ProcessAllAvailable(),
+ StopStream
+ )
+ val stateBytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+ validateBytesReadDfSchema(stateBytesDf)
+
+ Seq("right-keyToNumValues",
"left-keyToNumValues").foreach(colFamilyName => {
+ val normalDf = getNormalReadDf(tempDir.getAbsolutePath,
Option(colFamilyName))
+
+ val keyToNumValuesKeySchema = StructType(Array(
+ StructField("key", IntegerType)
+ ))
+ val keyToNumValueValueSchema = StructType(Array(
+ StructField("value", LongType)
+ ))
+
+ compareNormalAndBytesData(
+ normalDf.collect(),
+ stateBytesDf.collect(),
+ colFamilyName,
+ keyToNumValuesKeySchema,
+ keyToNumValueValueSchema)
+ })
+
+ Seq("right-keyWithIndexToValue",
"left-keyWithIndexToValue").foreach(colFamilyName => {
+ val normalDf = getNormalReadDf(tempDir.getAbsolutePath,
Option(colFamilyName))
+ val keyToNumValuesKeySchema = StructType(Array(
+ StructField("key", IntegerType, nullable = false),
+ StructField("index", LongType)
+ ))
+ val keyToNumValueValueSchema = StructType(Array(
+ StructField("value", IntegerType, nullable = false),
+ StructField("time", TimestampType, nullable = false),
+ StructField("matched", BooleanType)
+ ))
+
+ compareNormalAndBytesData(
+ normalDf.collect(),
+ stateBytesDf.collect(),
+ colFamilyName,
+ keyToNumValuesKeySchema,
+ keyToNumValueValueSchema)
+ })
+ }
+ }
+ }
+
+ test("SPARK-54419: stream-stream joinV3") {
+ testStreamStreamJoinV3()
Review Comment:
FYI, no need to create a separate func if it is only used by one test case.
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala:
##########
@@ -533,4 +643,399 @@ class StatePartitionAllColumnFamiliesReaderSuite extends
StateDataSourceTestBase
}
}
}
+
+ test("SPARK-54419: transformWithState with multiple column families") {
+ withTempDir { tempDir =>
+ val inputData = MemoryStream[String]
+ val result = inputData.toDS()
+ .groupByKey(x => x)
+ .transformWithState(new MultiStateVarProcessor(),
+ TimeMode.None(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath),
+ AddData(inputData, "a", "b", "a"),
+ CheckNewAnswer(("a", "2"), ("b", "1")),
+ AddData(inputData, "b", "c"),
+ CheckNewAnswer(("b", "2"), ("c", "1")),
+ StopStream
+ )
+
+ // Read all column families using internalOnlyReadAllColumnFamilies
+ val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+ validateBytesReadDfSchema(bytesDf)
+ val allBytesData = bytesDf.collect()
+
+ val columnFamilies = allBytesData.map(_.getString(3)).distinct.sorted
+
+ // Verify countState column family exists
+ assert(columnFamilies.toSet ==
+ Set("countState", "itemsList", "$rowCounter_itemsList", "itemsMap"))
+
+ // Define schemas for each column family based on provided schema info
+ val groupByKeySchema = StructType(Array(
+ StructField("value", StringType, nullable = true)
+ ))
+ val countStateValueSchema = StructType(Array(
+ StructField("value", LongType, nullable = false)
+ ))
+ val itemsListValueSchema = StructType(Array(
+ StructField("value", StringType, nullable = true)
+ ))
+ val rowCounterValueSchema = StructType(Array(
+ StructField("count", LongType, nullable = true)
+ ))
+ val itemsMapKeySchema = StructType(Array(
+ StructField("key", groupByKeySchema),
+ StructField("user_map_key", groupByKeySchema, nullable = true)
+ ))
+ val itemsMapValueSchema = StructType(Array(
+ StructField("user_map_value", IntegerType, nullable = true)
+ ))
+
+ // Validate countState
+ readAndValidateStateVar(
+ tempDir.getAbsolutePath, allBytesData,
+ stateVarName = "countState", groupByKeySchema, countStateValueSchema)
+
+ // Validate itemsList
+ readAndValidateStateVar(
+ tempDir.getAbsolutePath, allBytesData,
+ stateVarName = "itemsList", groupByKeySchema, itemsListValueSchema,
+ extraOptions = Map(StateSourceOptions.FLATTEN_COLLECTION_TYPES ->
"true"),
+ selectExprs = Seq("partition_id", "key", "list_element"))
+
+ // Validate $rowCounter_itemsList - intentionally reuses countState's
data
+ val countStateNormalDf = getNormalReadDf(tempDir.getAbsolutePath,
Option("countState"))
+ compareNormalAndBytesData(
+ countStateNormalDf.collect(),
+ allBytesData,
+ "$rowCounter_itemsList",
+ groupByKeySchema,
+ rowCounterValueSchema)
+
+ // Validate itemsMap
+ readAndValidateStateVar(
+ tempDir.getAbsolutePath, allBytesData,
+ stateVarName = "itemsMap", itemsMapKeySchema, itemsMapValueSchema,
+ selectExprs = Seq("partition_id", "STRUCT(key, user_map_key) AS KEY",
+ "user_map_value AS value"))
+ }
+ }
+
+ test("SPARK-54419: read all column families with event time timers") {
+ withTempDir { tempDir =>
+ val inputData = MemoryStream[(String, Long)]
+ val result = inputData.toDS()
+ .select(col("_1").as("key"),
timestamp_seconds(col("_2")).as("eventTime"))
+ .withWatermark("eventTime", "10 seconds")
+ .as[(String, Timestamp)]
+ .groupByKey(_._1)
+ .transformWithState(
+ new EventTimeTimerProcessor(),
+ TimeMode.EventTime(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath),
+ AddData(inputData, ("a", 1L), ("b", 2L), ("c", 3L)),
+ CheckLastBatch(("a", "1"), ("b", "1"), ("c", "1")),
+ StopStream
+ )
+
+ validateTimerColumnFamilies(tempDir.getAbsolutePath, "event")
+ }
+ }
+
+ test("SPARK-54419: read all column families with processing time timers") {
+ withTempDir { tempDir =>
+ val clock = new StreamManualClock
+ val inputData = MemoryStream[String]
+ val result = inputData.toDS()
+ .groupByKey(x => x)
+ .transformWithState(new
RunningCountStatefulProcessorWithProcTimeTimer(),
+ TimeMode.ProcessingTime(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath,
+ trigger = Trigger.ProcessingTime("1 second"),
+ triggerClock = clock),
+ AddData(inputData, "a"),
+ AdvanceManualClock(1 * 1000),
+ CheckNewAnswer(("a", "1")),
+ StopStream
+ )
+
+ validateTimerColumnFamilies(tempDir.getAbsolutePath, "proc")
+ }
+ }
+
+ test("SPARK-54419: transformWithState with list state and TTL") {
+ withTempDir { tempDir =>
+ val clock = new StreamManualClock
+ val inputData = MemoryStream[String]
+ val result = inputData.toDS()
+ .groupByKey(x => x)
+ .transformWithState(new ListStateTTLProcessor(),
+ TimeMode.ProcessingTime(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath,
+ trigger = Trigger.ProcessingTime("1 second"),
+ triggerClock = clock),
+ AddData(inputData, "a", "b", "a"),
+ AdvanceManualClock(1 * 1000),
+ CheckNewAnswer(("a", "2"), ("b", "1")),
+ StopStream
+ )
+
+ val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+ validateBytesReadDfSchema(bytesDf)
+
+ val allBytesData = bytesDf.collect()
+ val columnFamilies = allBytesData.map(_.getString(3)).distinct.sorted
+
+ assert(columnFamilies.toSet ==
+ Set("listState", "$ttl_listState", "$min_listState",
"$count_listState"))
+
+ // Define schemas for list state with TTL column families
+ val groupByKeySchema = StructType(Array(
+ StructField("value", StringType, nullable = true)
+ ))
+ val listStateValueSchema = StructType(Array(
+ StructField("value", StructType(Array(
+ StructField("value", StringType, nullable = true)
+ )), nullable = false),
+ StructField("ttlExpirationMs", LongType, nullable = false)
+ ))
+
+ val listStateNormalDf = spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+ .option(StateSourceOptions.STATE_VAR_NAME, "listState")
+ .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, "true")
+ .load()
+ .selectExpr("partition_id", "key", "list_element")
+
+ compareNormalAndBytesData(
+ listStateNormalDf.collect(),
+ allBytesData,
+ "listState",
+ groupByKeySchema,
+ listStateValueSchema)
+ val dummyValueSchema = StructType(Array(StructField("__dummy__",
NullType)))
+ val ttlIndexKeySchema = StructType(Array(
+ StructField("expirationMs", LongType, nullable = false),
+ StructField("elementKey", groupByKeySchema)
+ ))
+ val minExpiryValueSchema = StructType(Array(
+ StructField("minExpiry", LongType)
+ ))
+ val countValueSchema = StructType(Array(
+ StructField("count", LongType)
+ ))
+ val columnFamilyAndKeyValueSchema = Seq(
+ ("$ttl_listState", ttlIndexKeySchema, dummyValueSchema),
+ ("$min_listState", groupByKeySchema, minExpiryValueSchema),
+ ("$count_listState", groupByKeySchema, countValueSchema)
+ )
+ columnFamilyAndKeyValueSchema.foreach(pair => {
+ val normalDf = spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+ .option(StateSourceOptions.STATE_VAR_NAME, pair._1)
+ .load()
+ .selectExpr("partition_id", "key", "value")
+
+ compareNormalAndBytesData(
+ normalDf.collect(),
+ allBytesData,
+ pair._1,
+ pair._2,
+ pair._3)
+ }
+ )
+ }
+ }
+
+ def testStreamStreamJoinV3(): Unit = {
+ withSQLConf(
+ SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION.key -> "3"
+ ) {
+ withTempDir { tempDir =>
+ val inputData = MemoryStream[(Int, Long)]
+ val query = getStreamStreamJoinQuery(inputData)
+ testStream(query)(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath),
+ AddData(inputData, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)),
+ ProcessAllAvailable(),
+ StopStream
+ )
+ val stateBytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+ validateBytesReadDfSchema(stateBytesDf)
+
+ Seq("right-keyToNumValues",
"left-keyToNumValues").foreach(colFamilyName => {
+ val normalDf = getNormalReadDf(tempDir.getAbsolutePath,
Option(colFamilyName))
+
+ val keyToNumValuesKeySchema = StructType(Array(
+ StructField("key", IntegerType)
+ ))
+ val keyToNumValueValueSchema = StructType(Array(
+ StructField("value", LongType)
+ ))
+
+ compareNormalAndBytesData(
+ normalDf.collect(),
+ stateBytesDf.collect(),
+ colFamilyName,
+ keyToNumValuesKeySchema,
+ keyToNumValueValueSchema)
+ })
+
+ Seq("right-keyWithIndexToValue",
"left-keyWithIndexToValue").foreach(colFamilyName => {
+ val normalDf = getNormalReadDf(tempDir.getAbsolutePath,
Option(colFamilyName))
+ val keyToNumValuesKeySchema = StructType(Array(
+ StructField("key", IntegerType, nullable = false),
+ StructField("index", LongType)
+ ))
+ val keyToNumValueValueSchema = StructType(Array(
+ StructField("value", IntegerType, nullable = false),
+ StructField("time", TimestampType, nullable = false),
+ StructField("matched", BooleanType)
+ ))
+
+ compareNormalAndBytesData(
+ normalDf.collect(),
+ stateBytesDf.collect(),
+ colFamilyName,
+ keyToNumValuesKeySchema,
+ keyToNumValueValueSchema)
+ })
+ }
+ }
+ }
+
+ test("SPARK-54419: stream-stream joinV3") {
+ testStreamStreamJoinV3()
+ }
+}
+
+/**
+ * Stateful processor with multiple state variables (ValueState + ListState)
+ * for testing multi-column family reading.
+ */
+class MultiStateVarProcessor extends StatefulProcessor[String, String,
(String, String)] {
+ @transient private var _countState: ValueState[Long] = _
+ @transient private var _itemsList: ListState[String] = _
+ @transient private var _itemsMap: MapState[String, SimpleMapValue] = _
+
+ override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
+ _countState = getHandle.getValueState[Long]("countState",
Encoders.scalaLong, TTLConfig.NONE)
+ _itemsList = getHandle.getListState[String]("itemsList", Encoders.STRING,
TTLConfig.NONE)
+ _itemsMap = getHandle.getMapState[String, SimpleMapValue](
+ "itemsMap", Encoders.STRING, Encoders.product[SimpleMapValue],
TTLConfig.NONE)
+ }
+
+ override def handleInputRows(
+ key: String,
+ inputRows: Iterator[String],
+ timerValues: TimerValues): Iterator[(String, String)] = {
+ val currentCount = Option(_countState.get()).getOrElse(0L)
+ var newCount = currentCount
+ inputRows.foreach { item =>
+ newCount += 1
+ _itemsList.appendValue(item)
+ _itemsMap.updateValue(item, SimpleMapValue(newCount.toInt))
+ }
+ _countState.update(newCount)
+ Iterator((key, newCount.toString))
+ }
+}
+
Review Comment:
Is there a reason why you're implementing your own processors instead of
reusing the ones in TransformWithStateSuite:
https://github.com/apache/spark/blob/25c57d87382ca90850541d4d20447bfd842ec85a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala#L606
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala:
##########
@@ -533,4 +643,399 @@ class StatePartitionAllColumnFamiliesReaderSuite extends
StateDataSourceTestBase
}
}
}
+
+ test("SPARK-54419: transformWithState with multiple column families") {
+ withTempDir { tempDir =>
+ val inputData = MemoryStream[String]
+ val result = inputData.toDS()
+ .groupByKey(x => x)
+ .transformWithState(new MultiStateVarProcessor(),
+ TimeMode.None(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath),
+ AddData(inputData, "a", "b", "a"),
+ CheckNewAnswer(("a", "2"), ("b", "1")),
+ AddData(inputData, "b", "c"),
+ CheckNewAnswer(("b", "2"), ("c", "1")),
+ StopStream
+ )
+
+ // Read all column families using internalOnlyReadAllColumnFamilies
+ val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+ validateBytesReadDfSchema(bytesDf)
+ val allBytesData = bytesDf.collect()
+
+ val columnFamilies = allBytesData.map(_.getString(3)).distinct.sorted
+
+ // Verify countState column family exists
+ assert(columnFamilies.toSet ==
+ Set("countState", "itemsList", "$rowCounter_itemsList", "itemsMap"))
+
+ // Define schemas for each column family based on provided schema info
+ val groupByKeySchema = StructType(Array(
+ StructField("value", StringType, nullable = true)
+ ))
+ val countStateValueSchema = StructType(Array(
+ StructField("value", LongType, nullable = false)
+ ))
+ val itemsListValueSchema = StructType(Array(
+ StructField("value", StringType, nullable = true)
+ ))
+ val rowCounterValueSchema = StructType(Array(
+ StructField("count", LongType, nullable = true)
+ ))
+ val itemsMapKeySchema = StructType(Array(
+ StructField("key", groupByKeySchema),
+ StructField("user_map_key", groupByKeySchema, nullable = true)
+ ))
+ val itemsMapValueSchema = StructType(Array(
+ StructField("user_map_value", IntegerType, nullable = true)
+ ))
+
+ // Validate countState
+ readAndValidateStateVar(
+ tempDir.getAbsolutePath, allBytesData,
+ stateVarName = "countState", groupByKeySchema, countStateValueSchema)
+
+ // Validate itemsList
+ readAndValidateStateVar(
+ tempDir.getAbsolutePath, allBytesData,
+ stateVarName = "itemsList", groupByKeySchema, itemsListValueSchema,
+ extraOptions = Map(StateSourceOptions.FLATTEN_COLLECTION_TYPES ->
"true"),
+ selectExprs = Seq("partition_id", "key", "list_element"))
+
+ // Validate $rowCounter_itemsList - intentionally reuses countState's
data
+ val countStateNormalDf = getNormalReadDf(tempDir.getAbsolutePath,
Option("countState"))
+ compareNormalAndBytesData(
+ countStateNormalDf.collect(),
+ allBytesData,
+ "$rowCounter_itemsList",
+ groupByKeySchema,
+ rowCounterValueSchema)
+
+ // Validate itemsMap
+ readAndValidateStateVar(
+ tempDir.getAbsolutePath, allBytesData,
+ stateVarName = "itemsMap", itemsMapKeySchema, itemsMapValueSchema,
+ selectExprs = Seq("partition_id", "STRUCT(key, user_map_key) AS KEY",
+ "user_map_value AS value"))
+ }
+ }
+
+ test("SPARK-54419: read all column families with event time timers") {
+ withTempDir { tempDir =>
+ val inputData = MemoryStream[(String, Long)]
+ val result = inputData.toDS()
+ .select(col("_1").as("key"),
timestamp_seconds(col("_2")).as("eventTime"))
+ .withWatermark("eventTime", "10 seconds")
+ .as[(String, Timestamp)]
+ .groupByKey(_._1)
+ .transformWithState(
+ new EventTimeTimerProcessor(),
+ TimeMode.EventTime(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath),
+ AddData(inputData, ("a", 1L), ("b", 2L), ("c", 3L)),
+ CheckLastBatch(("a", "1"), ("b", "1"), ("c", "1")),
+ StopStream
+ )
+
+ validateTimerColumnFamilies(tempDir.getAbsolutePath, "event")
+ }
+ }
+
+ test("SPARK-54419: read all column families with processing time timers") {
+ withTempDir { tempDir =>
+ val clock = new StreamManualClock
+ val inputData = MemoryStream[String]
+ val result = inputData.toDS()
+ .groupByKey(x => x)
+ .transformWithState(new
RunningCountStatefulProcessorWithProcTimeTimer(),
+ TimeMode.ProcessingTime(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath,
+ trigger = Trigger.ProcessingTime("1 second"),
+ triggerClock = clock),
+ AddData(inputData, "a"),
+ AdvanceManualClock(1 * 1000),
+ CheckNewAnswer(("a", "1")),
+ StopStream
+ )
+
+ validateTimerColumnFamilies(tempDir.getAbsolutePath, "proc")
+ }
+ }
+
+ test("SPARK-54419: transformWithState with list state and TTL") {
+ withTempDir { tempDir =>
+ val clock = new StreamManualClock
+ val inputData = MemoryStream[String]
+ val result = inputData.toDS()
+ .groupByKey(x => x)
+ .transformWithState(new ListStateTTLProcessor(),
+ TimeMode.ProcessingTime(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath,
+ trigger = Trigger.ProcessingTime("1 second"),
+ triggerClock = clock),
+ AddData(inputData, "a", "b", "a"),
+ AdvanceManualClock(1 * 1000),
+ CheckNewAnswer(("a", "2"), ("b", "1")),
+ StopStream
+ )
+
+ val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+ validateBytesReadDfSchema(bytesDf)
+
+ val allBytesData = bytesDf.collect()
+ val columnFamilies = allBytesData.map(_.getString(3)).distinct.sorted
+
+ assert(columnFamilies.toSet ==
+ Set("listState", "$ttl_listState", "$min_listState",
"$count_listState"))
+
+ // Define schemas for list state with TTL column families
+ val groupByKeySchema = StructType(Array(
+ StructField("value", StringType, nullable = true)
+ ))
+ val listStateValueSchema = StructType(Array(
+ StructField("value", StructType(Array(
+ StructField("value", StringType, nullable = true)
+ )), nullable = false),
+ StructField("ttlExpirationMs", LongType, nullable = false)
+ ))
+
+ val listStateNormalDf = spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+ .option(StateSourceOptions.STATE_VAR_NAME, "listState")
+ .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, "true")
+ .load()
+ .selectExpr("partition_id", "key", "list_element")
+
+ compareNormalAndBytesData(
+ listStateNormalDf.collect(),
+ allBytesData,
+ "listState",
+ groupByKeySchema,
+ listStateValueSchema)
+ val dummyValueSchema = StructType(Array(StructField("__dummy__",
NullType)))
+ val ttlIndexKeySchema = StructType(Array(
+ StructField("expirationMs", LongType, nullable = false),
+ StructField("elementKey", groupByKeySchema)
+ ))
+ val minExpiryValueSchema = StructType(Array(
+ StructField("minExpiry", LongType)
+ ))
+ val countValueSchema = StructType(Array(
+ StructField("count", LongType)
+ ))
+ val columnFamilyAndKeyValueSchema = Seq(
+ ("$ttl_listState", ttlIndexKeySchema, dummyValueSchema),
+ ("$min_listState", groupByKeySchema, minExpiryValueSchema),
+ ("$count_listState", groupByKeySchema, countValueSchema)
+ )
+ columnFamilyAndKeyValueSchema.foreach(pair => {
+ val normalDf = spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+ .option(StateSourceOptions.STATE_VAR_NAME, pair._1)
+ .load()
+ .selectExpr("partition_id", "key", "value")
+
+ compareNormalAndBytesData(
+ normalDf.collect(),
+ allBytesData,
+ pair._1,
+ pair._2,
+ pair._3)
+ }
+ )
+ }
+ }
+
+ def testStreamStreamJoinV3(): Unit = {
+ withSQLConf(
+ SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION.key -> "3"
+ ) {
+ withTempDir { tempDir =>
+ val inputData = MemoryStream[(Int, Long)]
+ val query = getStreamStreamJoinQuery(inputData)
+ testStream(query)(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath),
+ AddData(inputData, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)),
+ ProcessAllAvailable(),
+ StopStream
+ )
+ val stateBytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+ validateBytesReadDfSchema(stateBytesDf)
+
+ Seq("right-keyToNumValues",
"left-keyToNumValues").foreach(colFamilyName => {
+ val normalDf = getNormalReadDf(tempDir.getAbsolutePath,
Option(colFamilyName))
+
+ val keyToNumValuesKeySchema = StructType(Array(
+ StructField("key", IntegerType)
+ ))
+ val keyToNumValueValueSchema = StructType(Array(
+ StructField("value", LongType)
+ ))
+
+ compareNormalAndBytesData(
+ normalDf.collect(),
+ stateBytesDf.collect(),
+ colFamilyName,
+ keyToNumValuesKeySchema,
+ keyToNumValueValueSchema)
+ })
+
+ Seq("right-keyWithIndexToValue",
"left-keyWithIndexToValue").foreach(colFamilyName => {
+ val normalDf = getNormalReadDf(tempDir.getAbsolutePath,
Option(colFamilyName))
+ val keyToNumValuesKeySchema = StructType(Array(
+ StructField("key", IntegerType, nullable = false),
+ StructField("index", LongType)
+ ))
+ val keyToNumValueValueSchema = StructType(Array(
+ StructField("value", IntegerType, nullable = false),
+ StructField("time", TimestampType, nullable = false),
+ StructField("matched", BooleanType)
+ ))
+
+ compareNormalAndBytesData(
+ normalDf.collect(),
+ stateBytesDf.collect(),
+ colFamilyName,
+ keyToNumValuesKeySchema,
+ keyToNumValueValueSchema)
+ })
+ }
+ }
+ }
+
+ test("SPARK-54419: stream-stream joinV3") {
+ testStreamStreamJoinV3()
+ }
+}
+
+/**
+ * Stateful processor with multiple state variables (ValueState + ListState)
+ * for testing multi-column family reading.
+ */
+class MultiStateVarProcessor extends StatefulProcessor[String, String,
(String, String)] {
+ @transient private var _countState: ValueState[Long] = _
+ @transient private var _itemsList: ListState[String] = _
+ @transient private var _itemsMap: MapState[String, SimpleMapValue] = _
+
+ override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
+ _countState = getHandle.getValueState[Long]("countState",
Encoders.scalaLong, TTLConfig.NONE)
+ _itemsList = getHandle.getListState[String]("itemsList", Encoders.STRING,
TTLConfig.NONE)
+ _itemsMap = getHandle.getMapState[String, SimpleMapValue](
+ "itemsMap", Encoders.STRING, Encoders.product[SimpleMapValue],
TTLConfig.NONE)
+ }
+
+ override def handleInputRows(
Review Comment:
nit: indentation
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala:
##########
@@ -610,6 +610,10 @@ private[sql] class RocksDBStateStoreProvider
override def hasCommitted: Boolean = state == COMMITTED
+ override def allColumnFamilyNames: collection.Set[String] = {
Review Comment:
ditto
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala:
##########
@@ -336,6 +336,14 @@ class RocksDB(
colFamilyNameToInfoMap.asScala.values.toSeq.count(_.isInternal ==
isInternal)
}
+ /**
+ * Returns all column family names currently registered in RocksDB.
+ * This includes column families loaded from checkpoint metadata.
+ */
+ def allColumnFamilyNames: collection.Set[String] = {
Review Comment:
ditto
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala:
##########
@@ -146,6 +146,9 @@ private[sql] class HDFSBackedStateStoreProvider extends
StateStoreProvider with
throw StateStoreErrors.multipleColumnFamiliesNotSupported(providerName)
}
+ override def allColumnFamilyNames: collection.Set[String] =
Review Comment:
Why not just use `Set[String]`? `collection.Set` is the base class of all sets, and here it means you can return either a mutable or an immutable set. Using `Set[String]` makes it clear that we are only returning an immutable set.
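
A sketch of the suggested signature change; `internalColFamilyNames` is a hypothetical placeholder for whatever mutable collection the provider tracks internally, the point being only that the public signature exposes an immutable `Set[String]`:

```scala
// Convert to an immutable Set at the API boundary.
override def allColumnFamilyNames: Set[String] =
  internalColFamilyNames.toSet
```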
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala:
##########
@@ -533,4 +643,399 @@ class StatePartitionAllColumnFamiliesReaderSuite extends
StateDataSourceTestBase
}
}
}
+
+ test("SPARK-54419: transformWithState with multiple column families") {
+ withTempDir { tempDir =>
+ val inputData = MemoryStream[String]
+ val result = inputData.toDS()
+ .groupByKey(x => x)
+ .transformWithState(new MultiStateVarProcessor(),
+ TimeMode.None(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath),
+ AddData(inputData, "a", "b", "a"),
+ CheckNewAnswer(("a", "2"), ("b", "1")),
+ AddData(inputData, "b", "c"),
+ CheckNewAnswer(("b", "2"), ("c", "1")),
+ StopStream
+ )
+
+ // Read all column families using internalOnlyReadAllColumnFamilies
+ val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+ validateBytesReadDfSchema(bytesDf)
+ val allBytesData = bytesDf.collect()
+
+ val columnFamilies = allBytesData.map(_.getString(3)).distinct.sorted
+
+ // Verify countState column family exists
+ assert(columnFamilies.toSet ==
+ Set("countState", "itemsList", "$rowCounter_itemsList", "itemsMap"))
+
+ // Define schemas for each column family based on provided schema info
+ val groupByKeySchema = StructType(Array(
+ StructField("value", StringType, nullable = true)
+ ))
+ val countStateValueSchema = StructType(Array(
+ StructField("value", LongType, nullable = false)
+ ))
+ val itemsListValueSchema = StructType(Array(
+ StructField("value", StringType, nullable = true)
+ ))
+ val rowCounterValueSchema = StructType(Array(
+ StructField("count", LongType, nullable = true)
+ ))
+ val itemsMapKeySchema = StructType(Array(
+ StructField("key", groupByKeySchema),
+ StructField("user_map_key", groupByKeySchema, nullable = true)
+ ))
+ val itemsMapValueSchema = StructType(Array(
+ StructField("user_map_value", IntegerType, nullable = true)
+ ))
+
+ // Validate countState
+ readAndValidateStateVar(
+ tempDir.getAbsolutePath, allBytesData,
+ stateVarName = "countState", groupByKeySchema, countStateValueSchema)
+
+ // Validate itemsList
+ readAndValidateStateVar(
+ tempDir.getAbsolutePath, allBytesData,
+ stateVarName = "itemsList", groupByKeySchema, itemsListValueSchema,
+ extraOptions = Map(StateSourceOptions.FLATTEN_COLLECTION_TYPES ->
"true"),
+ selectExprs = Seq("partition_id", "key", "list_element"))
+
+ // Validate $rowCounter_itemsList - intentionally reuses countState's
data
Review Comment:
nit: add a comment explaining why
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala:
##########
@@ -258,32 +266,116 @@ class StatePartitionAllColumnFamiliesReader(
partition: StateStoreInputPartition,
schema: StructType,
keyStateEncoderSpec: KeyStateEncoderSpec,
- stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema])
+ defaultStateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
+ stateSchemaProviderOpt: Option[StateSchemaProvider],
+ allColumnFamiliesReaderInfo: AllColumnFamiliesReaderInfo)
extends StatePartitionReaderBase(
storeConf,
hadoopConf, partition, schema,
- keyStateEncoderSpec, None, stateStoreColFamilySchemaOpt, None, None) {
+ keyStateEncoderSpec, None,
+ defaultStateStoreColFamilySchemaOpt,
+ stateSchemaProviderOpt, None) {
- private lazy val store: ReadStateStore = {
+ private val stateStoreColFamilySchemas =
allColumnFamiliesReaderInfo.colFamilySchemas
+ private val stateVariableInfos =
allColumnFamiliesReaderInfo.stateVariableInfos
+
+ private def isListType(colFamilyName: String): Boolean = {
+ SchemaUtil.checkVariableType(
+ stateVariableInfos.find(info => info.stateName == colFamilyName),
+ StateVariableType.ListState)
+ }
+
+ override protected lazy val provider: StateStoreProvider = {
+ val stateStoreId =
StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
+ partition.sourceOptions.operatorId, partition.partition,
partition.sourceOptions.storeName)
+ val stateStoreProviderId = StateStoreProviderId(stateStoreId,
partition.queryId)
+ val useColumnFamilies = stateStoreColFamilySchemas.length > 1
+ StateStoreProvider.createAndInit(
+ stateStoreProviderId, keySchema, valueSchema, keyStateEncoderSpec,
+ useColumnFamilies, storeConf, hadoopConf.value,
+ useMultipleValuesPerKey = false, stateSchemaProviderOpt)
+ }
+
+
+ private def checkAllColFamiliesExist(
+ colFamilyNames: List[String], stateStore: StateStore
+ ): Unit = {
+ // Filter out DEFAULT column family from validation for two reasons:
+ // 1. Some operators (e.g., stream-stream join v3) don't include DEFAULT
in their schema
+ // because the underlying RocksDB creates "default" column family
automatically
+ // 2. The default column family schema is handled separately via
+ // defaultStateStoreColFamilySchemaOpt, so no need to verify it here
+ val actualCFs = colFamilyNames.toSet.filter(_ !=
StateStore.DEFAULT_COL_FAMILY_NAME)
+ val expectedCFs = stateStore.allColumnFamilyNames
+ .filter(_ != StateStore.DEFAULT_COL_FAMILY_NAME)
+
+ // Validation: All column families found in the checkpoint must be
declared in the schema.
+ // It's acceptable if some schema CFs are not in expectedCFs - this just
means those
+ // column families have no data yet in the checkpoint
+ // (they'll be created during registration).
+ // However, if the checkpoint contains CFs not in the schema, it indicates
a mismatch.
+ require(expectedCFs.subsetOf(actualCFs),
+ s"Checkpoint contains unexpected column families. " +
+ s"Column families in checkpoint but not in schema:
${expectedCFs.diff(actualCFs)}")
Review Comment:
nit: Column families in state store but not in metadata:
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala:
##########
@@ -533,4 +643,399 @@ class StatePartitionAllColumnFamiliesReaderSuite extends
StateDataSourceTestBase
}
}
}
+
+ test("SPARK-54419: transformWithState with multiple column families") {
+ withTempDir { tempDir =>
+ val inputData = MemoryStream[String]
+ val result = inputData.toDS()
+ .groupByKey(x => x)
+ .transformWithState(new MultiStateVarProcessor(),
+ TimeMode.None(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath),
+ AddData(inputData, "a", "b", "a"),
+ CheckNewAnswer(("a", "2"), ("b", "1")),
+ AddData(inputData, "b", "c"),
+ CheckNewAnswer(("b", "2"), ("c", "1")),
+ StopStream
+ )
+
+ // Read all column families using internalOnlyReadAllColumnFamilies
+ val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+ validateBytesReadDfSchema(bytesDf)
+ val allBytesData = bytesDf.collect()
+
+ val columnFamilies = allBytesData.map(_.getString(3)).distinct.sorted
+
+ // Verify countState column family exists
+ assert(columnFamilies.toSet ==
+ Set("countState", "itemsList", "$rowCounter_itemsList", "itemsMap"))
+
+ // Define schemas for each column family based on provided schema info
+ val groupByKeySchema = StructType(Array(
+ StructField("value", StringType, nullable = true)
+ ))
+ val countStateValueSchema = StructType(Array(
+ StructField("value", LongType, nullable = false)
+ ))
+ val itemsListValueSchema = StructType(Array(
+ StructField("value", StringType, nullable = true)
+ ))
+ val rowCounterValueSchema = StructType(Array(
+ StructField("count", LongType, nullable = true)
+ ))
+ val itemsMapKeySchema = StructType(Array(
+ StructField("key", groupByKeySchema),
+ StructField("user_map_key", groupByKeySchema, nullable = true)
+ ))
+ val itemsMapValueSchema = StructType(Array(
+ StructField("user_map_value", IntegerType, nullable = true)
+ ))
+
+ // Validate countState
+ readAndValidateStateVar(
+ tempDir.getAbsolutePath, allBytesData,
+ stateVarName = "countState", groupByKeySchema, countStateValueSchema)
+
+ // Validate itemsList
+ readAndValidateStateVar(
+ tempDir.getAbsolutePath, allBytesData,
+ stateVarName = "itemsList", groupByKeySchema, itemsListValueSchema,
+ extraOptions = Map(StateSourceOptions.FLATTEN_COLLECTION_TYPES ->
"true"),
+ selectExprs = Seq("partition_id", "key", "list_element"))
+
+ // Validate $rowCounter_itemsList - intentionally reuses countState's
data
+ val countStateNormalDf = getNormalReadDf(tempDir.getAbsolutePath,
Option("countState"))
+ compareNormalAndBytesData(
+ countStateNormalDf.collect(),
+ allBytesData,
+ "$rowCounter_itemsList",
+ groupByKeySchema,
+ rowCounterValueSchema)
+
+ // Validate itemsMap
+ readAndValidateStateVar(
+ tempDir.getAbsolutePath, allBytesData,
+ stateVarName = "itemsMap", itemsMapKeySchema, itemsMapValueSchema,
+ selectExprs = Seq("partition_id", "STRUCT(key, user_map_key) AS KEY",
+ "user_map_value AS value"))
+ }
+ }
+
+ test("SPARK-54419: read all column families with event time timers") {
+ withTempDir { tempDir =>
+ val inputData = MemoryStream[(String, Long)]
+ val result = inputData.toDS()
+ .select(col("_1").as("key"),
timestamp_seconds(col("_2")).as("eventTime"))
+ .withWatermark("eventTime", "10 seconds")
+ .as[(String, Timestamp)]
+ .groupByKey(_._1)
+ .transformWithState(
+ new EventTimeTimerProcessor(),
+ TimeMode.EventTime(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath),
+ AddData(inputData, ("a", 1L), ("b", 2L), ("c", 3L)),
+ CheckLastBatch(("a", "1"), ("b", "1"), ("c", "1")),
+ StopStream
+ )
+
+ validateTimerColumnFamilies(tempDir.getAbsolutePath, "event")
+ }
+ }
+
+ test("SPARK-54419: read all column families with processing time timers") {
+ withTempDir { tempDir =>
+ val clock = new StreamManualClock
+ val inputData = MemoryStream[String]
+ val result = inputData.toDS()
+ .groupByKey(x => x)
+ .transformWithState(new
RunningCountStatefulProcessorWithProcTimeTimer(),
+ TimeMode.ProcessingTime(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath,
+ trigger = Trigger.ProcessingTime("1 second"),
+ triggerClock = clock),
+ AddData(inputData, "a"),
+ AdvanceManualClock(1 * 1000),
+ CheckNewAnswer(("a", "1")),
+ StopStream
+ )
+
+ validateTimerColumnFamilies(tempDir.getAbsolutePath, "proc")
+ }
+ }
+
+ test("SPARK-54419: transformWithState with list state and TTL") {
Review Comment:
Should we add tests for map and value state TTL too?
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala:
##########
@@ -533,4 +643,399 @@ class StatePartitionAllColumnFamiliesReaderSuite extends
StateDataSourceTestBase
}
}
}
+
+ test("SPARK-54419: transformWithState with multiple column families") {
+ withTempDir { tempDir =>
+ val inputData = MemoryStream[String]
+ val result = inputData.toDS()
+ .groupByKey(x => x)
+ .transformWithState(new MultiStateVarProcessor(),
+ TimeMode.None(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath),
+ AddData(inputData, "a", "b", "a"),
+ CheckNewAnswer(("a", "2"), ("b", "1")),
+ AddData(inputData, "b", "c"),
+ CheckNewAnswer(("b", "2"), ("c", "1")),
+ StopStream
+ )
+
+ // Read all column families using internalOnlyReadAllColumnFamilies
+ val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+ validateBytesReadDfSchema(bytesDf)
+ val allBytesData = bytesDf.collect()
+
+ val columnFamilies = allBytesData.map(_.getString(3)).distinct.sorted
+
+ // Verify countState column family exists
+ assert(columnFamilies.toSet ==
+ Set("countState", "itemsList", "$rowCounter_itemsList", "itemsMap"))
+
+ // Define schemas for each column family based on provided schema info
+ val groupByKeySchema = StructType(Array(
+ StructField("value", StringType, nullable = true)
+ ))
+ val countStateValueSchema = StructType(Array(
+ StructField("value", LongType, nullable = false)
+ ))
+ val itemsListValueSchema = StructType(Array(
+ StructField("value", StringType, nullable = true)
+ ))
+ val rowCounterValueSchema = StructType(Array(
+ StructField("count", LongType, nullable = true)
+ ))
+ val itemsMapKeySchema = StructType(Array(
+ StructField("key", groupByKeySchema),
+ StructField("user_map_key", groupByKeySchema, nullable = true)
+ ))
+ val itemsMapValueSchema = StructType(Array(
+ StructField("user_map_value", IntegerType, nullable = true)
+ ))
+
+ // Validate countState
+ readAndValidateStateVar(
+ tempDir.getAbsolutePath, allBytesData,
+ stateVarName = "countState", groupByKeySchema, countStateValueSchema)
+
+ // Validate itemsList
+ readAndValidateStateVar(
+ tempDir.getAbsolutePath, allBytesData,
+ stateVarName = "itemsList", groupByKeySchema, itemsListValueSchema,
+ extraOptions = Map(StateSourceOptions.FLATTEN_COLLECTION_TYPES ->
"true"),
+ selectExprs = Seq("partition_id", "key", "list_element"))
+
+ // Validate $rowCounter_itemsList - intentionally reuses countState's
data
+ val countStateNormalDf = getNormalReadDf(tempDir.getAbsolutePath,
Option("countState"))
+ compareNormalAndBytesData(
+ countStateNormalDf.collect(),
+ allBytesData,
+ "$rowCounter_itemsList",
+ groupByKeySchema,
+ rowCounterValueSchema)
+
+ // Validate itemsMap
+ readAndValidateStateVar(
+ tempDir.getAbsolutePath, allBytesData,
+ stateVarName = "itemsMap", itemsMapKeySchema, itemsMapValueSchema,
+ selectExprs = Seq("partition_id", "STRUCT(key, user_map_key) AS KEY",
+ "user_map_value AS value"))
+ }
+ }
+
+ test("SPARK-54419: read all column families with event time timers") {
+ withTempDir { tempDir =>
+ val inputData = MemoryStream[(String, Long)]
+ val result = inputData.toDS()
+ .select(col("_1").as("key"),
timestamp_seconds(col("_2")).as("eventTime"))
+ .withWatermark("eventTime", "10 seconds")
+ .as[(String, Timestamp)]
+ .groupByKey(_._1)
+ .transformWithState(
+ new EventTimeTimerProcessor(),
+ TimeMode.EventTime(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath),
+ AddData(inputData, ("a", 1L), ("b", 2L), ("c", 3L)),
+ CheckLastBatch(("a", "1"), ("b", "1"), ("c", "1")),
+ StopStream
+ )
+
+ validateTimerColumnFamilies(tempDir.getAbsolutePath, "event")
+ }
+ }
+
+ test("SPARK-54419: read all column families with processing time timers") {
+ withTempDir { tempDir =>
+ val clock = new StreamManualClock
+ val inputData = MemoryStream[String]
+ val result = inputData.toDS()
+ .groupByKey(x => x)
+ .transformWithState(new
RunningCountStatefulProcessorWithProcTimeTimer(),
+ TimeMode.ProcessingTime(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath,
+ trigger = Trigger.ProcessingTime("1 second"),
+ triggerClock = clock),
+ AddData(inputData, "a"),
+ AdvanceManualClock(1 * 1000),
+ CheckNewAnswer(("a", "1")),
+ StopStream
+ )
+
+ validateTimerColumnFamilies(tempDir.getAbsolutePath, "proc")
+ }
+ }
+
+ test("SPARK-54419: transformWithState with list state and TTL") {
+ withTempDir { tempDir =>
+ val clock = new StreamManualClock
+ val inputData = MemoryStream[String]
+ val result = inputData.toDS()
+ .groupByKey(x => x)
+ .transformWithState(new ListStateTTLProcessor(),
+ TimeMode.ProcessingTime(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath,
+ trigger = Trigger.ProcessingTime("1 second"),
+ triggerClock = clock),
+ AddData(inputData, "a", "b", "a"),
+ AdvanceManualClock(1 * 1000),
+ CheckNewAnswer(("a", "2"), ("b", "1")),
+ StopStream
+ )
+
+ val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+ validateBytesReadDfSchema(bytesDf)
+
+ val allBytesData = bytesDf.collect()
+ val columnFamilies = allBytesData.map(_.getString(3)).distinct.sorted
+
+ assert(columnFamilies.toSet ==
+ Set("listState", "$ttl_listState", "$min_listState",
"$count_listState"))
+
+ // Define schemas for list state with TTL column families
+ val groupByKeySchema = StructType(Array(
+ StructField("value", StringType, nullable = true)
+ ))
+ val listStateValueSchema = StructType(Array(
+ StructField("value", StructType(Array(
+ StructField("value", StringType, nullable = true)
+ )), nullable = false),
+ StructField("ttlExpirationMs", LongType, nullable = false)
+ ))
+
+ val listStateNormalDf = spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+ .option(StateSourceOptions.STATE_VAR_NAME, "listState")
+ .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, "true")
+ .load()
+ .selectExpr("partition_id", "key", "list_element")
+
+ compareNormalAndBytesData(
+ listStateNormalDf.collect(),
+ allBytesData,
+ "listState",
+ groupByKeySchema,
+ listStateValueSchema)
+ val dummyValueSchema = StructType(Array(StructField("__dummy__",
NullType)))
+ val ttlIndexKeySchema = StructType(Array(
+ StructField("expirationMs", LongType, nullable = false),
+ StructField("elementKey", groupByKeySchema)
+ ))
+ val minExpiryValueSchema = StructType(Array(
+ StructField("minExpiry", LongType)
+ ))
+ val countValueSchema = StructType(Array(
+ StructField("count", LongType)
+ ))
+ val columnFamilyAndKeyValueSchema = Seq(
+ ("$ttl_listState", ttlIndexKeySchema, dummyValueSchema),
+ ("$min_listState", groupByKeySchema, minExpiryValueSchema),
+ ("$count_listState", groupByKeySchema, countValueSchema)
+ )
+ columnFamilyAndKeyValueSchema.foreach(pair => {
+ val normalDf = spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+ .option(StateSourceOptions.STATE_VAR_NAME, pair._1)
+ .load()
+ .selectExpr("partition_id", "key", "value")
+
+ compareNormalAndBytesData(
+ normalDf.collect(),
+ allBytesData,
+ pair._1,
+ pair._2,
+ pair._3)
+ }
+ )
+ }
+ }
+
+ def testStreamStreamJoinV3(): Unit = {
Review Comment:
This seems to be doing the exact same thing as the `testStreamStreamJoinV2` func. Why not use a single func for both and pass in different params? Or am I missing something?
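
A sketch of what a single parameterized helper might look like, reusing the setup that both existing funcs already share (version-specific column-family validation elided; it would come from the two current helpers):

```scala
def testStreamStreamJoin(stateVersion: Int): Unit = {
  withSQLConf(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION.key -> stateVersion.toString) {
    withTempDir { tempDir =>
      val inputData = MemoryStream[(Int, Long)]
      val query = getStreamStreamJoinQuery(inputData)
      testStream(query)(
        StartStream(checkpointLocation = tempDir.getAbsolutePath),
        AddData(inputData, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)),
        ProcessAllAvailable(),
        StopStream
      )
      val stateBytesDf = getBytesReadDf(tempDir.getAbsolutePath)
      validateBytesReadDfSchema(stateBytesDf)
      // version-specific column-family validation would go here
    }
  }
}
```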
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]