zifeif2 commented on code in PR #53316:
URL: https://github.com/apache/spark/pull/53316#discussion_r2624893986
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala:
##########
@@ -533,4 +643,399 @@ class StatePartitionAllColumnFamiliesReaderSuite extends
StateDataSourceTestBase
}
}
}
+
+ test("SPARK-54419: transformWithState with multiple column families") {
+ withTempDir { tempDir =>
+ val inputData = MemoryStream[String]
+ val result = inputData.toDS()
+ .groupByKey(x => x)
+ .transformWithState(new MultiStateVarProcessor(),
+ TimeMode.None(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath),
+ AddData(inputData, "a", "b", "a"),
+ CheckNewAnswer(("a", "2"), ("b", "1")),
+ AddData(inputData, "b", "c"),
+ CheckNewAnswer(("b", "2"), ("c", "1")),
+ StopStream
+ )
+
+ // Read all column families using internalOnlyReadAllColumnFamilies
+ val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+ validateBytesReadDfSchema(bytesDf)
+ val allBytesData = bytesDf.collect()
+
+ val columnFamilies = allBytesData.map(_.getString(3)).distinct.sorted
+
+ // Verify countState column family exists
+ assert(columnFamilies.toSet ==
+ Set("countState", "itemsList", "$rowCounter_itemsList", "itemsMap"))
+
+ // Define schemas for each column family based on provided schema info
+ val groupByKeySchema = StructType(Array(
+ StructField("value", StringType, nullable = true)
+ ))
+ val countStateValueSchema = StructType(Array(
+ StructField("value", LongType, nullable = false)
+ ))
+ val itemsListValueSchema = StructType(Array(
+ StructField("value", StringType, nullable = true)
+ ))
+ val rowCounterValueSchema = StructType(Array(
+ StructField("count", LongType, nullable = true)
+ ))
+ val itemsMapKeySchema = StructType(Array(
+ StructField("key", groupByKeySchema),
+ StructField("user_map_key", groupByKeySchema, nullable = true)
+ ))
+ val itemsMapValueSchema = StructType(Array(
+ StructField("user_map_value", IntegerType, nullable = true)
+ ))
+
+ // Validate countState
+ readAndValidateStateVar(
+ tempDir.getAbsolutePath, allBytesData,
+ stateVarName = "countState", groupByKeySchema, countStateValueSchema)
+
+ // Validate itemsList
+ readAndValidateStateVar(
+ tempDir.getAbsolutePath, allBytesData,
+ stateVarName = "itemsList", groupByKeySchema, itemsListValueSchema,
+ extraOptions = Map(StateSourceOptions.FLATTEN_COLLECTION_TYPES ->
"true"),
+ selectExprs = Seq("partition_id", "key", "list_element"))
+
+ // Validate $rowCounter_itemsList - intentionally reuses countState's
data
+ val countStateNormalDf = getNormalReadDf(tempDir.getAbsolutePath,
Option("countState"))
+ compareNormalAndBytesData(
+ countStateNormalDf.collect(),
+ allBytesData,
+ "$rowCounter_itemsList",
+ groupByKeySchema,
+ rowCounterValueSchema)
+
+ // Validate itemsMap
+ readAndValidateStateVar(
+ tempDir.getAbsolutePath, allBytesData,
+ stateVarName = "itemsMap", itemsMapKeySchema, itemsMapValueSchema,
+ selectExprs = Seq("partition_id", "STRUCT(key, user_map_key) AS KEY",
+ "user_map_value AS value"))
+ }
+ }
+
+ test("SPARK-54419: read all column families with event time timers") {
+ withTempDir { tempDir =>
+ val inputData = MemoryStream[(String, Long)]
+ val result = inputData.toDS()
+ .select(col("_1").as("key"),
timestamp_seconds(col("_2")).as("eventTime"))
+ .withWatermark("eventTime", "10 seconds")
+ .as[(String, Timestamp)]
+ .groupByKey(_._1)
+ .transformWithState(
+ new EventTimeTimerProcessor(),
+ TimeMode.EventTime(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath),
+ AddData(inputData, ("a", 1L), ("b", 2L), ("c", 3L)),
+ CheckLastBatch(("a", "1"), ("b", "1"), ("c", "1")),
+ StopStream
+ )
+
+ validateTimerColumnFamilies(tempDir.getAbsolutePath, "event")
+ }
+ }
+
+ test("SPARK-54419: read all column families with processing time timers") {
+ withTempDir { tempDir =>
+ val clock = new StreamManualClock
+ val inputData = MemoryStream[String]
+ val result = inputData.toDS()
+ .groupByKey(x => x)
+ .transformWithState(new
RunningCountStatefulProcessorWithProcTimeTimer(),
+ TimeMode.ProcessingTime(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath,
+ trigger = Trigger.ProcessingTime("1 second"),
+ triggerClock = clock),
+ AddData(inputData, "a"),
+ AdvanceManualClock(1 * 1000),
+ CheckNewAnswer(("a", "1")),
+ StopStream
+ )
+
+ validateTimerColumnFamilies(tempDir.getAbsolutePath, "proc")
+ }
+ }
+
+ test("SPARK-54419: transformWithState with list state and TTL") {
+ withTempDir { tempDir =>
+ val clock = new StreamManualClock
+ val inputData = MemoryStream[String]
+ val result = inputData.toDS()
+ .groupByKey(x => x)
+ .transformWithState(new ListStateTTLProcessor(),
+ TimeMode.ProcessingTime(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath,
+ trigger = Trigger.ProcessingTime("1 second"),
+ triggerClock = clock),
+ AddData(inputData, "a", "b", "a"),
+ AdvanceManualClock(1 * 1000),
+ CheckNewAnswer(("a", "2"), ("b", "1")),
+ StopStream
+ )
+
+ val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+ validateBytesReadDfSchema(bytesDf)
+
+ val allBytesData = bytesDf.collect()
+ val columnFamilies = allBytesData.map(_.getString(3)).distinct.sorted
+
+ assert(columnFamilies.toSet ==
+ Set("listState", "$ttl_listState", "$min_listState",
"$count_listState"))
+
+ // Define schemas for list state with TTL column families
+ val groupByKeySchema = StructType(Array(
+ StructField("value", StringType, nullable = true)
+ ))
+ val listStateValueSchema = StructType(Array(
+ StructField("value", StructType(Array(
+ StructField("value", StringType, nullable = true)
+ )), nullable = false),
+ StructField("ttlExpirationMs", LongType, nullable = false)
+ ))
+
+ val listStateNormalDf = spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+ .option(StateSourceOptions.STATE_VAR_NAME, "listState")
+ .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, "true")
+ .load()
+ .selectExpr("partition_id", "key", "list_element")
+
+ compareNormalAndBytesData(
+ listStateNormalDf.collect(),
+ allBytesData,
+ "listState",
+ groupByKeySchema,
+ listStateValueSchema)
+ val dummyValueSchema = StructType(Array(StructField("__dummy__",
NullType)))
+ val ttlIndexKeySchema = StructType(Array(
+ StructField("expirationMs", LongType, nullable = false),
+ StructField("elementKey", groupByKeySchema)
+ ))
+ val minExpiryValueSchema = StructType(Array(
+ StructField("minExpiry", LongType)
+ ))
+ val countValueSchema = StructType(Array(
+ StructField("count", LongType)
+ ))
+ val columnFamilyAndKeyValueSchema = Seq(
+ ("$ttl_listState", ttlIndexKeySchema, dummyValueSchema),
+ ("$min_listState", groupByKeySchema, minExpiryValueSchema),
+ ("$count_listState", groupByKeySchema, countValueSchema)
+ )
+ columnFamilyAndKeyValueSchema.foreach(pair => {
+ val normalDf = spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+ .option(StateSourceOptions.STATE_VAR_NAME, pair._1)
+ .load()
+ .selectExpr("partition_id", "key", "value")
+
+ compareNormalAndBytesData(
+ normalDf.collect(),
+ allBytesData,
+ pair._1,
+ pair._2,
+ pair._3)
+ }
+ )
+ }
+ }
+
+ def testStreamStreamJoinV3(): Unit = {
Review Comment:
They are not doing exactly the same thing. testStreamStreamJoinV1AndV2
queries bytesDf separately for each store name, while testStreamStreamJoinV3
queries a single bytesDf covering all column families.
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala:
##########
@@ -533,4 +643,399 @@ class StatePartitionAllColumnFamiliesReaderSuite extends
StateDataSourceTestBase
}
}
}
+
+ test("SPARK-54419: transformWithState with multiple column families") {
+ withTempDir { tempDir =>
+ val inputData = MemoryStream[String]
+ val result = inputData.toDS()
+ .groupByKey(x => x)
+ .transformWithState(new MultiStateVarProcessor(),
+ TimeMode.None(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath),
+ AddData(inputData, "a", "b", "a"),
+ CheckNewAnswer(("a", "2"), ("b", "1")),
+ AddData(inputData, "b", "c"),
+ CheckNewAnswer(("b", "2"), ("c", "1")),
+ StopStream
+ )
+
+ // Read all column families using internalOnlyReadAllColumnFamilies
+ val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+ validateBytesReadDfSchema(bytesDf)
+ val allBytesData = bytesDf.collect()
+
+ val columnFamilies = allBytesData.map(_.getString(3)).distinct.sorted
+
+ // Verify countState column family exists
+ assert(columnFamilies.toSet ==
+ Set("countState", "itemsList", "$rowCounter_itemsList", "itemsMap"))
+
+ // Define schemas for each column family based on provided schema info
+ val groupByKeySchema = StructType(Array(
+ StructField("value", StringType, nullable = true)
+ ))
+ val countStateValueSchema = StructType(Array(
+ StructField("value", LongType, nullable = false)
+ ))
+ val itemsListValueSchema = StructType(Array(
+ StructField("value", StringType, nullable = true)
+ ))
+ val rowCounterValueSchema = StructType(Array(
+ StructField("count", LongType, nullable = true)
+ ))
+ val itemsMapKeySchema = StructType(Array(
+ StructField("key", groupByKeySchema),
+ StructField("user_map_key", groupByKeySchema, nullable = true)
+ ))
+ val itemsMapValueSchema = StructType(Array(
+ StructField("user_map_value", IntegerType, nullable = true)
+ ))
+
+ // Validate countState
+ readAndValidateStateVar(
+ tempDir.getAbsolutePath, allBytesData,
+ stateVarName = "countState", groupByKeySchema, countStateValueSchema)
+
+ // Validate itemsList
+ readAndValidateStateVar(
+ tempDir.getAbsolutePath, allBytesData,
+ stateVarName = "itemsList", groupByKeySchema, itemsListValueSchema,
+ extraOptions = Map(StateSourceOptions.FLATTEN_COLLECTION_TYPES ->
"true"),
+ selectExprs = Seq("partition_id", "key", "list_element"))
+
+ // Validate $rowCounter_itemsList - intentionally reuses countState's
data
+ val countStateNormalDf = getNormalReadDf(tempDir.getAbsolutePath,
Option("countState"))
+ compareNormalAndBytesData(
+ countStateNormalDf.collect(),
+ allBytesData,
+ "$rowCounter_itemsList",
+ groupByKeySchema,
+ rowCounterValueSchema)
+
+ // Validate itemsMap
+ readAndValidateStateVar(
+ tempDir.getAbsolutePath, allBytesData,
+ stateVarName = "itemsMap", itemsMapKeySchema, itemsMapValueSchema,
+ selectExprs = Seq("partition_id", "STRUCT(key, user_map_key) AS KEY",
+ "user_map_value AS value"))
+ }
+ }
+
+ test("SPARK-54419: read all column families with event time timers") {
+ withTempDir { tempDir =>
+ val inputData = MemoryStream[(String, Long)]
+ val result = inputData.toDS()
+ .select(col("_1").as("key"),
timestamp_seconds(col("_2")).as("eventTime"))
+ .withWatermark("eventTime", "10 seconds")
+ .as[(String, Timestamp)]
+ .groupByKey(_._1)
+ .transformWithState(
+ new EventTimeTimerProcessor(),
+ TimeMode.EventTime(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath),
+ AddData(inputData, ("a", 1L), ("b", 2L), ("c", 3L)),
+ CheckLastBatch(("a", "1"), ("b", "1"), ("c", "1")),
+ StopStream
+ )
+
+ validateTimerColumnFamilies(tempDir.getAbsolutePath, "event")
+ }
+ }
+
+ test("SPARK-54419: read all column families with processing time timers") {
+ withTempDir { tempDir =>
+ val clock = new StreamManualClock
+ val inputData = MemoryStream[String]
+ val result = inputData.toDS()
+ .groupByKey(x => x)
+ .transformWithState(new
RunningCountStatefulProcessorWithProcTimeTimer(),
+ TimeMode.ProcessingTime(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath,
+ trigger = Trigger.ProcessingTime("1 second"),
+ triggerClock = clock),
+ AddData(inputData, "a"),
+ AdvanceManualClock(1 * 1000),
+ CheckNewAnswer(("a", "1")),
+ StopStream
+ )
+
+ validateTimerColumnFamilies(tempDir.getAbsolutePath, "proc")
+ }
+ }
+
+ test("SPARK-54419: transformWithState with list state and TTL") {
+ withTempDir { tempDir =>
+ val clock = new StreamManualClock
+ val inputData = MemoryStream[String]
+ val result = inputData.toDS()
+ .groupByKey(x => x)
+ .transformWithState(new ListStateTTLProcessor(),
+ TimeMode.ProcessingTime(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath,
+ trigger = Trigger.ProcessingTime("1 second"),
+ triggerClock = clock),
+ AddData(inputData, "a", "b", "a"),
+ AdvanceManualClock(1 * 1000),
+ CheckNewAnswer(("a", "2"), ("b", "1")),
+ StopStream
+ )
+
+ val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+ validateBytesReadDfSchema(bytesDf)
+
+ val allBytesData = bytesDf.collect()
+ val columnFamilies = allBytesData.map(_.getString(3)).distinct.sorted
+
+ assert(columnFamilies.toSet ==
+ Set("listState", "$ttl_listState", "$min_listState",
"$count_listState"))
+
+ // Define schemas for list state with TTL column families
+ val groupByKeySchema = StructType(Array(
+ StructField("value", StringType, nullable = true)
+ ))
+ val listStateValueSchema = StructType(Array(
+ StructField("value", StructType(Array(
+ StructField("value", StringType, nullable = true)
+ )), nullable = false),
+ StructField("ttlExpirationMs", LongType, nullable = false)
+ ))
+
+ val listStateNormalDf = spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+ .option(StateSourceOptions.STATE_VAR_NAME, "listState")
+ .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, "true")
+ .load()
+ .selectExpr("partition_id", "key", "list_element")
+
+ compareNormalAndBytesData(
+ listStateNormalDf.collect(),
+ allBytesData,
+ "listState",
+ groupByKeySchema,
+ listStateValueSchema)
+ val dummyValueSchema = StructType(Array(StructField("__dummy__",
NullType)))
+ val ttlIndexKeySchema = StructType(Array(
+ StructField("expirationMs", LongType, nullable = false),
+ StructField("elementKey", groupByKeySchema)
+ ))
+ val minExpiryValueSchema = StructType(Array(
+ StructField("minExpiry", LongType)
+ ))
+ val countValueSchema = StructType(Array(
+ StructField("count", LongType)
+ ))
+ val columnFamilyAndKeyValueSchema = Seq(
+ ("$ttl_listState", ttlIndexKeySchema, dummyValueSchema),
+ ("$min_listState", groupByKeySchema, minExpiryValueSchema),
+ ("$count_listState", groupByKeySchema, countValueSchema)
+ )
+ columnFamilyAndKeyValueSchema.foreach(pair => {
+ val normalDf = spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+ .option(StateSourceOptions.STATE_VAR_NAME, pair._1)
+ .load()
+ .selectExpr("partition_id", "key", "value")
+
+ compareNormalAndBytesData(
+ normalDf.collect(),
+ allBytesData,
+ pair._1,
+ pair._2,
+ pair._3)
+ }
+ )
+ }
+ }
+
+ def testStreamStreamJoinV3(): Unit = {
Review Comment:
They are not doing exactly the same thing. testStreamStreamJoinV1AndV2
queries bytesDf separately for each store name, while testStreamStreamJoinV3
queries a single bytesDf covering all column families.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]