micheal-o commented on code in PR #53386:
URL: https://github.com/apache/spark/pull/53386#discussion_r2636445004
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala:
##########
@@ -457,3 +458,340 @@ case class SessionUpdate(
durationMs: Long,
numEvents: Int,
expired: Boolean)
+
+case class ColumnFamilyMetadata(
+ keySchema: StructType,
+ valueSchema: StructType,
+ encoderSpec: KeyStateEncoderSpec,
+ useMultipleValuePerKey: Boolean = false)
+
+// Utility for runCompositeKeyStreamingAggregationQuery
+// todo: Move runCompositeKeyStreamingAggregationQuery to this class
+object CompositeKeyAggregationTestUtils {
+ def getSchemas(stateVersion: Int): (StructType, StructType) = {
+ val metadata = getSchemasWithMetadata(stateVersion)
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ def getSchemasWithMetadata(stateVersion: Int): ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false),
+ StructField("fruit", StringType, nullable = true)
+ ))
+
+ val valueSchema = if (stateVersion == 1) {
+ // State version 1 includes key columns in the value
+ StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false),
+ StructField("fruit", StringType, nullable = true),
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+ } else {
+ // State version 2 excludes key columns from the value
+ StructType(Array(
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+ }
+
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+}
+
+// Utility for runDropDuplicatesQueries
+// todo: Move runDropDuplicatesQueries to this class
+object DropDuplicatesTestUtils {
+ def getDropDuplicatesSchemas(): (StructType, StructType) = {
+ val metadata = getDropDuplicatesSchemasWithMetadata()
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ def getDropDuplicatesSchemasWithMetadata(): ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("value", IntegerType, nullable = false),
+ StructField("eventTime", TimestampType)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("__dummy__", NullType)
+ ))
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+
+ def getDropDuplicatesWithColumnSchemas(): (StructType, StructType) = {
+ val metadata = getDropDuplicatesWithColumnSchemasWithMetadata()
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ def getDropDuplicatesWithColumnSchemasWithMetadata(): ColumnFamilyMetadata =
{
+ val keySchema = StructType(Array(
+ StructField("col1", StringType, nullable = true)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("__dummy__", NullType, nullable = true)
+ ))
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+
+ def getDropDuplicatesWithinWatermarkSchemas(): (StructType, StructType) = {
+ val metadata = getDropDuplicatesWithinWatermarkSchemasWithMetadata()
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ def getDropDuplicatesWithinWatermarkSchemasWithMetadata():
ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("_1", StringType, nullable = true)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("expiresAtMicros", LongType, nullable = false)
+ ))
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+}
+
+/**
+ * Test utility object providing schema definitions for simple streaming
aggregation.
+ * Used by StatePartitionAllColumnFamiliesWriterSuite and
StatePartitionAllColumnFamiliesReaderSuite
+ * to eliminate code duplication.
+ */
+object SimpleAggregationTestUtils {
+ def getSchemas(stateVersion: Int): (StructType, StructType) = {
+ val metadata = getSchemasWithMetadata(stateVersion)
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ /**
+ * @param stateVersion The state format version:
+ * @return ColumnFamilyMetadata including schema and KeyEncoderSpec
+ */
+ def getSchemasWithMetadata(stateVersion: Int): ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false)
+ ))
+
+ val valueSchema = if (stateVersion == 1) {
+ // State version 1 includes key columns in the value
+ StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false),
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+ } else {
+ // State version 2 excludes key columns from the value
+ StructType(Array(
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+ }
+
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+}
+
+/**
+ * Test utility object providing schema definitions for flatMapGroupsWithState.
+ */
+object FlatMapGroupsWithStateTestUtils {
+
+ /**
+ * @param stateVersion The state format version:
Review Comment:
ditto
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala:
##########
@@ -457,3 +458,340 @@ case class SessionUpdate(
durationMs: Long,
numEvents: Int,
expired: Boolean)
+
+case class ColumnFamilyMetadata(
+ keySchema: StructType,
+ valueSchema: StructType,
+ encoderSpec: KeyStateEncoderSpec,
+ useMultipleValuePerKey: Boolean = false)
+
+// Utility for runCompositeKeyStreamingAggregationQuery
+// todo: Move runCompositeKeyStreamingAggregationQuery to this class
+object CompositeKeyAggregationTestUtils {
+ def getSchemas(stateVersion: Int): (StructType, StructType) = {
+ val metadata = getSchemasWithMetadata(stateVersion)
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ def getSchemasWithMetadata(stateVersion: Int): ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false),
+ StructField("fruit", StringType, nullable = true)
+ ))
+
+ val valueSchema = if (stateVersion == 1) {
+ // State version 1 includes key columns in the value
+ StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false),
+ StructField("fruit", StringType, nullable = true),
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+ } else {
+ // State version 2 excludes key columns from the value
+ StructType(Array(
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+ }
+
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+}
+
+// Utility for runDropDuplicatesQueries
+// todo: Move runDropDuplicatesQueries to this class
+object DropDuplicatesTestUtils {
+ def getDropDuplicatesSchemas(): (StructType, StructType) = {
+ val metadata = getDropDuplicatesSchemasWithMetadata()
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ def getDropDuplicatesSchemasWithMetadata(): ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("value", IntegerType, nullable = false),
+ StructField("eventTime", TimestampType)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("__dummy__", NullType)
+ ))
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+
+ def getDropDuplicatesWithColumnSchemas(): (StructType, StructType) = {
+ val metadata = getDropDuplicatesWithColumnSchemasWithMetadata()
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ def getDropDuplicatesWithColumnSchemasWithMetadata(): ColumnFamilyMetadata =
{
+ val keySchema = StructType(Array(
+ StructField("col1", StringType, nullable = true)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("__dummy__", NullType, nullable = true)
+ ))
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+
+ def getDropDuplicatesWithinWatermarkSchemas(): (StructType, StructType) = {
+ val metadata = getDropDuplicatesWithinWatermarkSchemasWithMetadata()
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ def getDropDuplicatesWithinWatermarkSchemasWithMetadata():
ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("_1", StringType, nullable = true)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("expiresAtMicros", LongType, nullable = false)
+ ))
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+}
+
+/**
+ * Test utility object providing schema definitions for simple streaming
aggregation.
+ * Used by StatePartitionAllColumnFamiliesWriterSuite and
StatePartitionAllColumnFamiliesReaderSuite
+ * to eliminate code duplication.
+ */
+object SimpleAggregationTestUtils {
+ def getSchemas(stateVersion: Int): (StructType, StructType) = {
+ val metadata = getSchemasWithMetadata(stateVersion)
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ /**
+ * @param stateVersion The state format version:
+ * @return ColumnFamilyMetadata including schema and KeyEncoderSpec
+ */
+ def getSchemasWithMetadata(stateVersion: Int): ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false)
+ ))
+
+ val valueSchema = if (stateVersion == 1) {
+ // State version 1 includes key columns in the value
+ StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false),
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+ } else {
+ // State version 2 excludes key columns from the value
+ StructType(Array(
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+ }
+
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+}
+
+/**
+ * Test utility object providing schema definitions for flatMapGroupsWithState.
+ */
+object FlatMapGroupsWithStateTestUtils {
+
+ /**
+ * @param stateVersion The state format version:
+ * @return A tuple of (keySchema, valueSchema)
+ */
+ def getSchemas(stateVersion: Int): (StructType, StructType) = {
+ val metadata = getSchemasWithMetadata(stateVersion)
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ /**
+ * @param stateVersion The state format version
+ * @return ColumnFamilyMetadata with schema and KeyEncoderSpec
+ */
+ def getSchemasWithMetadata(stateVersion: Int): ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("value", StringType, nullable = true)
+ ))
+
+ val valueSchema = if (stateVersion == 1) {
+ // State version 1: Flat structure
+ StructType(Array(
+ StructField("numEvents", IntegerType, nullable = false),
+ StructField("startTimestampMs", LongType, nullable = false),
+ StructField("endTimestampMs", LongType, nullable = false),
+ StructField("timeoutTimestamp", IntegerType, nullable = false)
+ ))
+ } else {
+ // State version 2: Nested structure with groupState wrapper
+ StructType(Array(
+ StructField("groupState", StructType(Array(
+ StructField("numEvents", IntegerType, nullable = false),
+ StructField("startTimestampMs", LongType, nullable = false),
+ StructField("endTimestampMs", LongType, nullable = false)
+ )), nullable = false),
+ StructField("timeoutTimestamp", LongType, nullable = false)
+ ))
+ }
+
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+}
+
+/**
+ * Test utility object providing schema definitions for
runSessionWindowAggregationQuery
+ */
+object SessionWindowTestUtils {
+
+ /**
+ * @return A tuple of (keySchema, valueSchema)
+ */
+ def getSchemas(): (StructType, StructType) = {
+ val metadata = getSchemasWithMetadata()
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ /**
+ * @return ColumnFamilyMetadata with schema and KeyEncoderSpec
+ */
+ def getSchemasWithMetadata(): ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("sessionId", StringType, nullable = false),
+ StructField("sessionStartTime", TimestampType, nullable = false)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("session_window", StructType(Array(
+ StructField("start", TimestampType),
+ StructField("end", TimestampType)
+ )), nullable = false),
+ StructField("sessionId", StringType, nullable = false),
+ StructField("count", LongType, nullable = false)
+ ))
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+}
+
+/**
+ * Test utility object providing schema definitions and constants for
runStreamStreamJoinQuery
+ */
+object StreamStreamJoinTestUtils {
+ // Column family names for keyToNumValues stores
+ val KEY_TO_NUM_VALUES_LEFT = "left-keyToNumValues"
Review Comment:
No need for this, you can call `allStateStoreNames` in
`SymmetricHashJoinStateManager` and it will return the list of names
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala:
##########
@@ -47,19 +54,41 @@ class StatePartitionAllColumnFamiliesWriter(
operatorId: Int,
storeName: String,
currentBatchId: Long,
- columnFamilyToSchemaMap: HashMap[String, StateStoreColFamilySchema]) {
+ colFamilyToWriterInfoMap: Map[String,
StatePartitionWriterColumnFamilyInfo],
+ operatorName: String,
+ schemaProviderOpt: Option[StateSchemaProvider],
+ sqlConf: Map[String, String]) {
+
+ private def isJoinV3Operator(
+ operatorName: String, sqlConf: Map[String, String]): Boolean = {
Review Comment:
ditto
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala:
##########
@@ -47,19 +54,41 @@ class StatePartitionAllColumnFamiliesWriter(
operatorId: Int,
storeName: String,
currentBatchId: Long,
- columnFamilyToSchemaMap: HashMap[String, StateStoreColFamilySchema]) {
+ colFamilyToWriterInfoMap: Map[String,
StatePartitionWriterColumnFamilyInfo],
+ operatorName: String,
+ schemaProviderOpt: Option[StateSchemaProvider],
+ sqlConf: Map[String, String]) {
+
+ private def isJoinV3Operator(
+ operatorName: String, sqlConf: Map[String, String]): Boolean = {
+ operatorName == StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME &&
+ sqlConf(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION.key) == "3"
+ }
+
private val defaultSchema = {
- columnFamilyToSchemaMap.getOrElse(
- StateStore.DEFAULT_COL_FAMILY_NAME,
- throw new IllegalArgumentException(
- s"Column family ${StateStore.DEFAULT_COL_FAMILY_NAME} not found in
schema map")
- )
+ colFamilyToWriterInfoMap.get(StateStore.DEFAULT_COL_FAMILY_NAME) match {
+ case Some(info) => info.schema
+ case None =>
+ assert(isJoinV3Operator(operatorName, sqlConf),
Review Comment:
nit: add a one-line comment noting that Join v3 doesn't write the default CF schema
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala:
##########
@@ -47,19 +54,41 @@ class StatePartitionAllColumnFamiliesWriter(
operatorId: Int,
storeName: String,
currentBatchId: Long,
- columnFamilyToSchemaMap: HashMap[String, StateStoreColFamilySchema]) {
+ colFamilyToWriterInfoMap: Map[String,
StatePartitionWriterColumnFamilyInfo],
+ operatorName: String,
+ schemaProviderOpt: Option[StateSchemaProvider],
+ sqlConf: Map[String, String]) {
Review Comment:
nit: take in `SQLConf`
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala:
##########
@@ -457,3 +458,340 @@ case class SessionUpdate(
durationMs: Long,
numEvents: Int,
expired: Boolean)
+
+case class ColumnFamilyMetadata(
+ keySchema: StructType,
+ valueSchema: StructType,
+ encoderSpec: KeyStateEncoderSpec,
+ useMultipleValuePerKey: Boolean = false)
+
+// Utility for runCompositeKeyStreamingAggregationQuery
+// todo: Move runCompositeKeyStreamingAggregationQuery to this class
+object CompositeKeyAggregationTestUtils {
+ def getSchemas(stateVersion: Int): (StructType, StructType) = {
+ val metadata = getSchemasWithMetadata(stateVersion)
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ def getSchemasWithMetadata(stateVersion: Int): ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false),
+ StructField("fruit", StringType, nullable = true)
+ ))
+
+ val valueSchema = if (stateVersion == 1) {
+ // State version 1 includes key columns in the value
+ StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false),
+ StructField("fruit", StringType, nullable = true),
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+ } else {
+ // State version 2 excludes key columns from the value
+ StructType(Array(
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+ }
+
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+}
+
+// Utility for runDropDuplicatesQueries
+// todo: Move runDropDuplicatesQueries to this class
+object DropDuplicatesTestUtils {
+ def getDropDuplicatesSchemas(): (StructType, StructType) = {
+ val metadata = getDropDuplicatesSchemasWithMetadata()
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ def getDropDuplicatesSchemasWithMetadata(): ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("value", IntegerType, nullable = false),
+ StructField("eventTime", TimestampType)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("__dummy__", NullType)
+ ))
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+
+ def getDropDuplicatesWithColumnSchemas(): (StructType, StructType) = {
+ val metadata = getDropDuplicatesWithColumnSchemasWithMetadata()
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ def getDropDuplicatesWithColumnSchemasWithMetadata(): ColumnFamilyMetadata =
{
+ val keySchema = StructType(Array(
+ StructField("col1", StringType, nullable = true)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("__dummy__", NullType, nullable = true)
+ ))
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+
+ def getDropDuplicatesWithinWatermarkSchemas(): (StructType, StructType) = {
+ val metadata = getDropDuplicatesWithinWatermarkSchemasWithMetadata()
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ def getDropDuplicatesWithinWatermarkSchemasWithMetadata():
ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("_1", StringType, nullable = true)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("expiresAtMicros", LongType, nullable = false)
+ ))
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+}
+
+/**
+ * Test utility object providing schema definitions for simple streaming
aggregation.
+ * Used by StatePartitionAllColumnFamiliesWriterSuite and
StatePartitionAllColumnFamiliesReaderSuite
+ * to eliminate code duplication.
+ */
+object SimpleAggregationTestUtils {
+ def getSchemas(stateVersion: Int): (StructType, StructType) = {
+ val metadata = getSchemasWithMetadata(stateVersion)
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ /**
+ * @param stateVersion The state format version:
+ * @return ColumnFamilyMetadata including schema and KeyEncoderSpec
+ */
+ def getSchemasWithMetadata(stateVersion: Int): ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false)
+ ))
+
+ val valueSchema = if (stateVersion == 1) {
+ // State version 1 includes key columns in the value
+ StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false),
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+ } else {
+ // State version 2 excludes key columns from the value
+ StructType(Array(
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+ }
+
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+}
+
+/**
+ * Test utility object providing schema definitions for flatMapGroupsWithState.
+ */
+object FlatMapGroupsWithStateTestUtils {
+
+ /**
+ * @param stateVersion The state format version:
+ * @return A tuple of (keySchema, valueSchema)
+ */
+ def getSchemas(stateVersion: Int): (StructType, StructType) = {
+ val metadata = getSchemasWithMetadata(stateVersion)
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ /**
+ * @param stateVersion The state format version
+ * @return ColumnFamilyMetadata with schema and KeyEncoderSpec
+ */
+ def getSchemasWithMetadata(stateVersion: Int): ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("value", StringType, nullable = true)
+ ))
+
+ val valueSchema = if (stateVersion == 1) {
+ // State version 1: Flat structure
+ StructType(Array(
+ StructField("numEvents", IntegerType, nullable = false),
+ StructField("startTimestampMs", LongType, nullable = false),
+ StructField("endTimestampMs", LongType, nullable = false),
+ StructField("timeoutTimestamp", IntegerType, nullable = false)
+ ))
+ } else {
+ // State version 2: Nested structure with groupState wrapper
+ StructType(Array(
+ StructField("groupState", StructType(Array(
+ StructField("numEvents", IntegerType, nullable = false),
+ StructField("startTimestampMs", LongType, nullable = false),
+ StructField("endTimestampMs", LongType, nullable = false)
+ )), nullable = false),
+ StructField("timeoutTimestamp", LongType, nullable = false)
+ ))
+ }
+
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+}
+
+/**
+ * Test utility object providing schema definitions for
runSessionWindowAggregationQuery
+ */
+object SessionWindowTestUtils {
+
+ /**
+ * @return A tuple of (keySchema, valueSchema)
+ */
+ def getSchemas(): (StructType, StructType) = {
+ val metadata = getSchemasWithMetadata()
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ /**
+ * @return ColumnFamilyMetadata with schema and KeyEncoderSpec
+ */
+ def getSchemasWithMetadata(): ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("sessionId", StringType, nullable = false),
+ StructField("sessionStartTime", TimestampType, nullable = false)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("session_window", StructType(Array(
+ StructField("start", TimestampType),
+ StructField("end", TimestampType)
+ )), nullable = false),
+ StructField("sessionId", StringType, nullable = false),
+ StructField("count", LongType, nullable = false)
+ ))
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+}
+
+/**
+ * Test utility object providing schema definitions and constants for
runStreamStreamJoinQuery
+ */
+object StreamStreamJoinTestUtils {
+ // Column family names for keyToNumValues stores
+ val KEY_TO_NUM_VALUES_LEFT = "left-keyToNumValues"
+ val KEY_TO_NUM_VALUES_RIGHT = "right-keyToNumValues"
+ val KEY_TO_NUM_VALUES_ALL: Seq[String] = Seq(
+ KEY_TO_NUM_VALUES_LEFT,
+ KEY_TO_NUM_VALUES_RIGHT
+ )
+
+ // Column family names for keyWithIndexToValue stores
+ val KEY_WITH_INDEX_LEFT = "left-keyWithIndexToValue"
+ val KEY_WITH_INDEX_RIGHT = "right-keyWithIndexToValue"
+ val KEY_WITH_INDEX_ALL: Seq[String] = Seq(
+ KEY_WITH_INDEX_LEFT,
+ KEY_WITH_INDEX_RIGHT
+ )
+
+ /**
+ * @return A tuple of (keySchema, valueSchema)
+ */
+ def getKeyToNumValuesSchemas(): (StructType, StructType) = {
+ val metadata = getKeyToNumValuesSchemasWithMetadata()
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ /**
+ * @return ColumnFamilyMetadata with schema and KeyEncoderSpec
+ */
+ def getKeyToNumValuesSchemasWithMetadata(): ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("key", IntegerType)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("value", LongType)
+ ))
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+
+ /**
+ * @param stateVersion The join state format version:
+ * @return A tuple of (keySchema, valueSchema)
+ */
+ def getKeyWithIndexToValueSchemas(stateVersion: Int): (StructType,
StructType) = {
+ val metadata = getKeyWithIndexToValueSchemasWithMetadata(stateVersion)
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ /**
+ * @param stateVersion The state format version
+ * @return ColumnFamilyMetadata with schema and KeyEncoderSpec
+ */
+ def getKeyWithIndexToValueSchemasWithMetadata(stateVersion: Int):
ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("key", IntegerType, nullable = false),
+ StructField("index", LongType)
+ ))
+
+ val valueSchema = if (stateVersion == 2 || stateVersion == 3) {
+ StructType(Array(
+ StructField("value", IntegerType, nullable = false),
+ StructField("time", TimestampType, nullable = false),
+ StructField("matched", BooleanType)
+ ))
+ } else {
+ StructType(Array(
+ StructField("value", IntegerType, nullable = false),
+ StructField("time", TimestampType, nullable = false)
+ ))
+ }
+
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+
+ /**
+ * Returns all schemas for stream-stream join V3 (multi-column family) in
legacy 2-tuple format.
+ * V3 uses a single state store with multiple column families instead of
separate stores.
+ *
+ * @return Map of column family name to (keySchema, valueSchema)
+ */
+ def getJoinV3ColumnSchemaMap(): Map[String, (StructType, StructType)] = {
+ getJoinV3ColumnSchemaMapWithMetadata().view.mapValues { metadata =>
+ (metadata.keySchema, metadata.valueSchema)
+ }.toMap
+ }
+
+ /**
+ * @return Map of column family name to ColumnFamilyMetadata
+ */
+ def getJoinV3ColumnSchemaMapWithMetadata(): Map[String,
ColumnFamilyMetadata] = {
+ val (keyToNumKeySchema, keyToNumValueSchema) = getKeyToNumValuesSchemas()
+ val (keyWithIndexKeySchema, keyWithIndexValueSchema) =
getKeyWithIndexToValueSchemas(3)
+
+ val keyToNumEncoderSpec = NoPrefixKeyStateEncoderSpec(keyToNumKeySchema)
+ val keyWithIndexEncoderSpec =
NoPrefixKeyStateEncoderSpec(keyWithIndexKeySchema)
+
+ Map(
+ KEY_TO_NUM_VALUES_LEFT -> ColumnFamilyMetadata(
Review Comment:
You can use `getStateStoreName(Left, KeyToNumValuesType)` from the join manager
here instead of hard-coding the name. Same for the other column family names.
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/TimerTestUtils.scala:
##########
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state.utils
+
+import java.sql.Timestamp
+
+import org.apache.spark.sql.Encoders
+import
org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.timers.TimerStateUtils
+import org.apache.spark.sql.execution.streaming.state.{KeyStateEncoderSpec,
NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec,
RangeKeyScanStateEncoderSpec, StateStore}
+import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor,
TimeMode, TimerValues, TTLConfig, ValueState}
+import org.apache.spark.sql.types.{BinaryType, LongType, NullType, StringType,
StructField, StructType}
+
+class EventTimeTimerProcessor
+ extends StatefulProcessor[String, (String, Timestamp), (String, String)] {
+ @transient var _valueState: ValueState[Long] = _
+
+ override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
+ _valueState = getHandle.getValueState("countState", Encoders.scalaLong,
TTLConfig.NONE)
+ }
+
+ override def handleInputRows(
+ key: String,
+ rows: Iterator[(String, Timestamp)],
+ timerValues: TimerValues): Iterator[(String, String)] = {
+ var maxTimestamp = 0L
+ var rowCount = 0
+ rows.foreach { case (_, timestamp) =>
+ maxTimestamp = Math.max(maxTimestamp, timestamp.getTime)
+ rowCount += 1
+ }
+
+ val count = Option(_valueState.get()).getOrElse(0L) + rowCount
+ _valueState.update(count)
+
+ // Register an event time timer
+ if (maxTimestamp > 0) {
+ getHandle.registerTimer(maxTimestamp + 5000)
+ }
+
+ Iterator((key, count.toString))
+ }
+}
+
+/**
+ * Test utility providing schema definitions and constants for
EventTimeTimerProcessor
+ * and RunningCountStatefulProcessorWithProcTimeTimer in
TransformWithStateSuite
+ */
+object TimerTestUtils {
+ case class ColumnFamilyMetadata(
+ keySchema: StructType,
+ valueSchema: StructType,
+ encoderSpec: KeyStateEncoderSpec,
+ useMultipleValuePerKey: Boolean = false)
+
+ /**
+ * Returns the grouping key schema and state value schema for a simple count
state.
+ * This is commonly used with timer tests where state tracks counts.
+ *
+ * @return A tuple of (keySchema, valueSchema)
+ */
+ def getCountStateSchemas(): (StructType, StructType) = {
+ val groupByKeySchema = StructType(Array(
+ StructField("key", StringType, nullable = true)
+ ))
+ val stateValueSchema = StructType(Array(
+ StructField("value", LongType, nullable = true)
+ ))
+ (groupByKeySchema, stateValueSchema)
+ }
+
+ /**
+ * Returns schemas for timer-related column families
+ *
+ * @param groupingKeySchema The schema for the grouping key
+ * @return A tuple of (keyToTimestampKeySchema, timestampToKeyKeySchema)
+ */
+ def getTimerKeySchemas(groupingKeySchema: StructType): (StructType,
StructType) = {
+ val keyToTimestampKeySchema = StructType(Array(
+ StructField("key", groupingKeySchema),
+ StructField("expiryTimestampMs", LongType, nullable = false)
+ ))
+ val timestampToKeyKeySchema = StructType(Array(
+ StructField("expiryTimestampMs", LongType, nullable = false),
+ StructField("key", groupingKeySchema)
+ ))
+
+ (keyToTimestampKeySchema, timestampToKeyKeySchema)
+ }
+
+ /**
+ * Returns complete metadata for timer-related column families including
encoder specs.
+ * Used for tests with timers and a count state.
+ * @param timeMode The time mode (EventTime or ProcessingTime)
+ * @return A map from column family name to ColumnFamilyMetadata
+ */
+ def getTimerConfigsForCountState(timeMode: TimeMode): Map[String,
ColumnFamilyMetadata] = {
+ val (groupByKeySchema, stateValueSchema) = getCountStateSchemas()
+ val (keyToTimestampKeySchema, timestampToKeyKeySchema) =
getTimerKeySchemas(groupByKeySchema)
+ val defaultValueSchema = StructType(Array(StructField("value", BinaryType,
nullable = true)))
+
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(groupByKeySchema)
+ val keyToTimestampEncoderSpec =
PrefixKeyScanStateEncoderSpec(keyToTimestampKeySchema, 1)
+ val timestampToKeyEncoderSpec =
RangeKeyScanStateEncoderSpec(timestampToKeyKeySchema, Seq(0))
+
+ val (keyToTimestampCF, timestampToKeyCF) =
+ TimerStateUtils.getTimerStateVarNames(timeMode.toString)
+
+ val dummyValueSchema = StructType(Array(StructField("__dummy__",
NullType)))
+
+ Map(
+ StateStore.DEFAULT_COL_FAMILY_NAME -> ColumnFamilyMetadata(
+ groupByKeySchema, defaultValueSchema, encoderSpec),
+ "countState" -> ColumnFamilyMetadata(
+ groupByKeySchema, stateValueSchema, encoderSpec),
+ keyToTimestampCF -> ColumnFamilyMetadata(
+ keyToTimestampKeySchema, dummyValueSchema, keyToTimestampEncoderSpec),
+ timestampToKeyCF -> ColumnFamilyMetadata(
+ timestampToKeyKeySchema, dummyValueSchema, timestampToKeyEncoderSpec)
+ )
+ }
+
+ /**
+ * Returns select expressions for timer column families with standardized
column order.
+ *
+ * @param columnFamilyName The timer column family name
+ * @return Select expressions as a Seq of strings with order (partition_id,
key, value)
+ */
+ def getTimerSelectExpressions(columnFamilyName: String): Seq[String] = {
+ if (columnFamilyName.endsWith("_keyToTimestamp")) {
+ Seq("partition_id",
+ "STRUCT(key AS groupingKey, expiration_timestamp_ms AS key)",
+ "NULL AS value")
+ } else if (columnFamilyName.endsWith("_timestampToKey")) {
Review Comment:
ditto
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala:
##########
@@ -457,3 +458,340 @@ case class SessionUpdate(
durationMs: Long,
numEvents: Int,
expired: Boolean)
+
+case class ColumnFamilyMetadata(
+ keySchema: StructType,
+ valueSchema: StructType,
+ encoderSpec: KeyStateEncoderSpec,
+ useMultipleValuePerKey: Boolean = false)
+
+// Utility for runCompositeKeyStreamingAggregationQuery
+// todo: Move runCompositeKeyStreamingAggregationQuery to this class
+object CompositeKeyAggregationTestUtils {
+ def getSchemas(stateVersion: Int): (StructType, StructType) = {
+ val metadata = getSchemasWithMetadata(stateVersion)
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ def getSchemasWithMetadata(stateVersion: Int): ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false),
+ StructField("fruit", StringType, nullable = true)
+ ))
+
+ val valueSchema = if (stateVersion == 1) {
+ // State version 1 includes key columns in the value
+ StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false),
+ StructField("fruit", StringType, nullable = true),
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+ } else {
+ // State version 2 excludes key columns from the value
+ StructType(Array(
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+ }
+
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+}
+
+// Utility for runDropDuplicatesQueries
+// todo: Move runDropDuplicatesQueries to this class
+object DropDuplicatesTestUtils {
+ def getDropDuplicatesSchemas(): (StructType, StructType) = {
+ val metadata = getDropDuplicatesSchemasWithMetadata()
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ def getDropDuplicatesSchemasWithMetadata(): ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("value", IntegerType, nullable = false),
+ StructField("eventTime", TimestampType)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("__dummy__", NullType)
+ ))
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+
+ def getDropDuplicatesWithColumnSchemas(): (StructType, StructType) = {
+ val metadata = getDropDuplicatesWithColumnSchemasWithMetadata()
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ def getDropDuplicatesWithColumnSchemasWithMetadata(): ColumnFamilyMetadata =
{
+ val keySchema = StructType(Array(
+ StructField("col1", StringType, nullable = true)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("__dummy__", NullType, nullable = true)
+ ))
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+
+ def getDropDuplicatesWithinWatermarkSchemas(): (StructType, StructType) = {
+ val metadata = getDropDuplicatesWithinWatermarkSchemasWithMetadata()
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ def getDropDuplicatesWithinWatermarkSchemasWithMetadata():
ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("_1", StringType, nullable = true)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("expiresAtMicros", LongType, nullable = false)
+ ))
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+}
+
+/**
+ * Test utility object providing schema definitions for simple streaming
aggregation.
+ * Used by StatePartitionAllColumnFamiliesWriterSuite and
StatePartitionAllColumnFamiliesReaderSuite
+ * to eliminate code duplication.
+ */
+object SimpleAggregationTestUtils {
+ def getSchemas(stateVersion: Int): (StructType, StructType) = {
+ val metadata = getSchemasWithMetadata(stateVersion)
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ /**
+ * @param stateVersion The state format version:
Review Comment:
nit: this description gives no valuable info; it can be removed.
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala:
##########
@@ -457,3 +458,340 @@ case class SessionUpdate(
durationMs: Long,
numEvents: Int,
expired: Boolean)
+
+case class ColumnFamilyMetadata(
+ keySchema: StructType,
+ valueSchema: StructType,
+ encoderSpec: KeyStateEncoderSpec,
+ useMultipleValuePerKey: Boolean = false)
+
+// Utility for runCompositeKeyStreamingAggregationQuery
+// todo: Move runCompositeKeyStreamingAggregationQuery to this class
+object CompositeKeyAggregationTestUtils {
+ def getSchemas(stateVersion: Int): (StructType, StructType) = {
+ val metadata = getSchemasWithMetadata(stateVersion)
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ def getSchemasWithMetadata(stateVersion: Int): ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false),
+ StructField("fruit", StringType, nullable = true)
+ ))
+
+ val valueSchema = if (stateVersion == 1) {
+ // State version 1 includes key columns in the value
+ StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false),
+ StructField("fruit", StringType, nullable = true),
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+ } else {
+ // State version 2 excludes key columns from the value
+ StructType(Array(
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+ }
+
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+}
+
+// Utility for runDropDuplicatesQueries
+// todo: Move runDropDuplicatesQueries to this class
+object DropDuplicatesTestUtils {
+ def getDropDuplicatesSchemas(): (StructType, StructType) = {
+ val metadata = getDropDuplicatesSchemasWithMetadata()
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ def getDropDuplicatesSchemasWithMetadata(): ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("value", IntegerType, nullable = false),
+ StructField("eventTime", TimestampType)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("__dummy__", NullType)
+ ))
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+
+ def getDropDuplicatesWithColumnSchemas(): (StructType, StructType) = {
+ val metadata = getDropDuplicatesWithColumnSchemasWithMetadata()
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ def getDropDuplicatesWithColumnSchemasWithMetadata(): ColumnFamilyMetadata =
{
+ val keySchema = StructType(Array(
+ StructField("col1", StringType, nullable = true)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("__dummy__", NullType, nullable = true)
+ ))
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+
+ def getDropDuplicatesWithinWatermarkSchemas(): (StructType, StructType) = {
+ val metadata = getDropDuplicatesWithinWatermarkSchemasWithMetadata()
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ def getDropDuplicatesWithinWatermarkSchemasWithMetadata():
ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("_1", StringType, nullable = true)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("expiresAtMicros", LongType, nullable = false)
+ ))
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+}
+
+/**
+ * Test utility object providing schema definitions for simple streaming
aggregation.
+ * Used by StatePartitionAllColumnFamiliesWriterSuite and
StatePartitionAllColumnFamiliesReaderSuite
+ * to eliminate code duplication.
+ */
+object SimpleAggregationTestUtils {
+ def getSchemas(stateVersion: Int): (StructType, StructType) = {
+ val metadata = getSchemasWithMetadata(stateVersion)
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ /**
+ * @param stateVersion The state format version:
+ * @return ColumnFamilyMetadata including schema and KeyEncoderSpec
+ */
+ def getSchemasWithMetadata(stateVersion: Int): ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false)
+ ))
+
+ val valueSchema = if (stateVersion == 1) {
+ // State version 1 includes key columns in the value
+ StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false),
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+ } else {
+ // State version 2 excludes key columns from the value
+ StructType(Array(
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+ }
+
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+}
+
+/**
+ * Test utility object providing schema definitions for flatMapGroupsWithState.
+ */
+object FlatMapGroupsWithStateTestUtils {
+
+ /**
+ * @param stateVersion The state format version:
+ * @return A tuple of (keySchema, valueSchema)
+ */
+ def getSchemas(stateVersion: Int): (StructType, StructType) = {
+ val metadata = getSchemasWithMetadata(stateVersion)
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ /**
+ * @param stateVersion The state format version
+ * @return ColumnFamilyMetadata with schema and KeyEncoderSpec
+ */
+ def getSchemasWithMetadata(stateVersion: Int): ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("value", StringType, nullable = true)
+ ))
+
+ val valueSchema = if (stateVersion == 1) {
+ // State version 1: Flat structure
+ StructType(Array(
+ StructField("numEvents", IntegerType, nullable = false),
+ StructField("startTimestampMs", LongType, nullable = false),
+ StructField("endTimestampMs", LongType, nullable = false),
+ StructField("timeoutTimestamp", IntegerType, nullable = false)
+ ))
+ } else {
+ // State version 2: Nested structure with groupState wrapper
+ StructType(Array(
+ StructField("groupState", StructType(Array(
+ StructField("numEvents", IntegerType, nullable = false),
+ StructField("startTimestampMs", LongType, nullable = false),
+ StructField("endTimestampMs", LongType, nullable = false)
+ )), nullable = false),
+ StructField("timeoutTimestamp", LongType, nullable = false)
+ ))
+ }
+
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+}
+
+/**
+ * Test utility object providing schema definitions for
runSessionWindowAggregationQuery
+ */
+object SessionWindowTestUtils {
+
+ /**
+ * @return A tuple of (keySchema, valueSchema)
Review Comment:
ditto
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala:
##########
@@ -457,3 +458,340 @@ case class SessionUpdate(
durationMs: Long,
numEvents: Int,
expired: Boolean)
+
+case class ColumnFamilyMetadata(
+ keySchema: StructType,
+ valueSchema: StructType,
+ encoderSpec: KeyStateEncoderSpec,
+ useMultipleValuePerKey: Boolean = false)
+
+// Utility for runCompositeKeyStreamingAggregationQuery
+// todo: Move runCompositeKeyStreamingAggregationQuery to this class
+object CompositeKeyAggregationTestUtils {
+ def getSchemas(stateVersion: Int): (StructType, StructType) = {
+ val metadata = getSchemasWithMetadata(stateVersion)
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ def getSchemasWithMetadata(stateVersion: Int): ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false),
+ StructField("fruit", StringType, nullable = true)
+ ))
+
+ val valueSchema = if (stateVersion == 1) {
+ // State version 1 includes key columns in the value
+ StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false),
+ StructField("fruit", StringType, nullable = true),
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+ } else {
+ // State version 2 excludes key columns from the value
+ StructType(Array(
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+ }
+
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+}
+
+// Utility for runDropDuplicatesQueries
+// todo: Move runDropDuplicatesQueries to this class
+object DropDuplicatesTestUtils {
+ def getDropDuplicatesSchemas(): (StructType, StructType) = {
+ val metadata = getDropDuplicatesSchemasWithMetadata()
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ def getDropDuplicatesSchemasWithMetadata(): ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("value", IntegerType, nullable = false),
+ StructField("eventTime", TimestampType)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("__dummy__", NullType)
+ ))
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+
+ def getDropDuplicatesWithColumnSchemas(): (StructType, StructType) = {
+ val metadata = getDropDuplicatesWithColumnSchemasWithMetadata()
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ def getDropDuplicatesWithColumnSchemasWithMetadata(): ColumnFamilyMetadata =
{
+ val keySchema = StructType(Array(
+ StructField("col1", StringType, nullable = true)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("__dummy__", NullType, nullable = true)
+ ))
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+
+ def getDropDuplicatesWithinWatermarkSchemas(): (StructType, StructType) = {
+ val metadata = getDropDuplicatesWithinWatermarkSchemasWithMetadata()
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ def getDropDuplicatesWithinWatermarkSchemasWithMetadata():
ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("_1", StringType, nullable = true)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("expiresAtMicros", LongType, nullable = false)
+ ))
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+}
+
+/**
+ * Test utility object providing schema definitions for simple streaming
aggregation.
+ * Used by StatePartitionAllColumnFamiliesWriterSuite and
StatePartitionAllColumnFamiliesReaderSuite
+ * to eliminate code duplication.
+ */
+object SimpleAggregationTestUtils {
+ def getSchemas(stateVersion: Int): (StructType, StructType) = {
+ val metadata = getSchemasWithMetadata(stateVersion)
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ /**
+ * @param stateVersion The state format version:
+ * @return ColumnFamilyMetadata including schema and KeyEncoderSpec
+ */
+ def getSchemasWithMetadata(stateVersion: Int): ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false)
+ ))
+
+ val valueSchema = if (stateVersion == 1) {
+ // State version 1 includes key columns in the value
+ StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false),
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+ } else {
+ // State version 2 excludes key columns from the value
+ StructType(Array(
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+ }
+
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+}
+
+/**
+ * Test utility object providing schema definitions for flatMapGroupsWithState.
+ */
+object FlatMapGroupsWithStateTestUtils {
+
+ /**
+ * @param stateVersion The state format version:
+ * @return A tuple of (keySchema, valueSchema)
+ */
+ def getSchemas(stateVersion: Int): (StructType, StructType) = {
+ val metadata = getSchemasWithMetadata(stateVersion)
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ /**
+ * @param stateVersion The state format version
+ * @return ColumnFamilyMetadata with schema and KeyEncoderSpec
+ */
+ def getSchemasWithMetadata(stateVersion: Int): ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("value", StringType, nullable = true)
+ ))
+
+ val valueSchema = if (stateVersion == 1) {
+ // State version 1: Flat structure
+ StructType(Array(
+ StructField("numEvents", IntegerType, nullable = false),
+ StructField("startTimestampMs", LongType, nullable = false),
+ StructField("endTimestampMs", LongType, nullable = false),
+ StructField("timeoutTimestamp", IntegerType, nullable = false)
+ ))
+ } else {
+ // State version 2: Nested structure with groupState wrapper
+ StructType(Array(
+ StructField("groupState", StructType(Array(
+ StructField("numEvents", IntegerType, nullable = false),
+ StructField("startTimestampMs", LongType, nullable = false),
+ StructField("endTimestampMs", LongType, nullable = false)
+ )), nullable = false),
+ StructField("timeoutTimestamp", LongType, nullable = false)
+ ))
+ }
+
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+}
+
+/**
+ * Test utility object providing schema definitions for
runSessionWindowAggregationQuery
+ */
+object SessionWindowTestUtils {
+
+ /**
+ * @return A tuple of (keySchema, valueSchema)
+ */
+ def getSchemas(): (StructType, StructType) = {
+ val metadata = getSchemasWithMetadata()
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ /**
+ * @return ColumnFamilyMetadata with schema and KeyEncoderSpec
Review Comment:
ditto — and please apply the same change to the remaining ones too
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala:
##########
@@ -123,6 +173,16 @@ class StatePartitionAllColumnFamiliesWriter(
val valueRow = new
UnsafeRow(columnFamilyToValueSchemaLenMap(colFamilyName))
valueRow.pointTo(valueBytes, valueBytes.length)
- stateStore.put(keyRow, valueRow, colFamilyName)
+ if (colFamilyToWriterInfoMap(colFamilyName).useMultipleValuesPerKey) {
+ // if a column family useMultipleValuesPerKey (e.g. ListType), we will
+ // write with 1 put followed by merge
+ if (stateStore.keyExists(keyRow, colFamilyName)) {
Review Comment:
These two ifs can be combined into one, i.e.
```
// if a column family useMultipleValuesPerKey (e.g. ListType), we will
// write with 1 put followed by merge
if (colFamilyToWriterInfoMap(colFamilyName).useMultipleValuesPerKey &&
stateStore.keyExists(keyRow, colFamilyName)) {
stateStore.merge(keyRow, valueRow, colFamilyName)
} else {
stateStore.put(keyRow, valueRow, colFamilyName)
}
```
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala:
##########
@@ -457,3 +458,340 @@ case class SessionUpdate(
durationMs: Long,
numEvents: Int,
expired: Boolean)
+
+case class ColumnFamilyMetadata(
+ keySchema: StructType,
+ valueSchema: StructType,
+ encoderSpec: KeyStateEncoderSpec,
+ useMultipleValuePerKey: Boolean = false)
+
+// Utility for runCompositeKeyStreamingAggregationQuery
+// todo: Move runCompositeKeyStreamingAggregationQuery to this class
+object CompositeKeyAggregationTestUtils {
+ def getSchemas(stateVersion: Int): (StructType, StructType) = {
+ val metadata = getSchemasWithMetadata(stateVersion)
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ def getSchemasWithMetadata(stateVersion: Int): ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false),
+ StructField("fruit", StringType, nullable = true)
+ ))
+
+ val valueSchema = if (stateVersion == 1) {
+ // State version 1 includes key columns in the value
+ StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false),
+ StructField("fruit", StringType, nullable = true),
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+ } else {
+ // State version 2 excludes key columns from the value
+ StructType(Array(
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+ }
+
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+}
+
+// Utility for runDropDuplicatesQueries
+// todo: Move runDropDuplicatesQueries to this class
+object DropDuplicatesTestUtils {
+ def getDropDuplicatesSchemas(): (StructType, StructType) = {
+ val metadata = getDropDuplicatesSchemasWithMetadata()
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ def getDropDuplicatesSchemasWithMetadata(): ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("value", IntegerType, nullable = false),
+ StructField("eventTime", TimestampType)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("__dummy__", NullType)
+ ))
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+
+ def getDropDuplicatesWithColumnSchemas(): (StructType, StructType) = {
+ val metadata = getDropDuplicatesWithColumnSchemasWithMetadata()
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ def getDropDuplicatesWithColumnSchemasWithMetadata(): ColumnFamilyMetadata =
{
+ val keySchema = StructType(Array(
+ StructField("col1", StringType, nullable = true)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("__dummy__", NullType, nullable = true)
+ ))
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+
+ def getDropDuplicatesWithinWatermarkSchemas(): (StructType, StructType) = {
+ val metadata = getDropDuplicatesWithinWatermarkSchemasWithMetadata()
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ def getDropDuplicatesWithinWatermarkSchemasWithMetadata():
ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("_1", StringType, nullable = true)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("expiresAtMicros", LongType, nullable = false)
+ ))
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+}
+
+/**
+ * Test utility object providing schema definitions for simple streaming
aggregation.
+ * Used by StatePartitionAllColumnFamiliesWriterSuite and
StatePartitionAllColumnFamiliesReaderSuite
+ * to eliminate code duplication.
+ */
+object SimpleAggregationTestUtils {
+ def getSchemas(stateVersion: Int): (StructType, StructType) = {
+ val metadata = getSchemasWithMetadata(stateVersion)
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ /**
+ * @param stateVersion The state format version:
+ * @return ColumnFamilyMetadata including schema and KeyEncoderSpec
+ */
+ def getSchemasWithMetadata(stateVersion: Int): ColumnFamilyMetadata = {
+ val keySchema = StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false)
+ ))
+
+ val valueSchema = if (stateVersion == 1) {
+ // State version 1 includes key columns in the value
+ StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false),
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+ } else {
+ // State version 2 excludes key columns from the value
+ StructType(Array(
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+ }
+
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+ ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
+ }
+}
+
+/**
+ * Test utility object providing schema definitions for flatMapGroupsWithState.
+ */
+object FlatMapGroupsWithStateTestUtils {
+
+ /**
+ * @param stateVersion The state format version:
+ * @return A tuple of (keySchema, valueSchema)
+ */
+ def getSchemas(stateVersion: Int): (StructType, StructType) = {
+ val metadata = getSchemasWithMetadata(stateVersion)
+ (metadata.keySchema, metadata.valueSchema)
+ }
+
+ /**
+ * @param stateVersion The state format version
Review Comment:
ditto
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/TimerTestUtils.scala:
##########
@@ -0,0 +1,170 @@
+/*
Review Comment:
Let's combine these 3 utils into one file — I mean the MultiStateVarTestUtils
file + TimerTestUtils + TTLProcessorUtils. There's no need for separate files: they
are all for TWS testing and do similar things, and combining them makes them easier to find.
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/TimerTestUtils.scala:
##########
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state.utils
+
+import java.sql.Timestamp
+
+import org.apache.spark.sql.Encoders
+import
org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.timers.TimerStateUtils
+import org.apache.spark.sql.execution.streaming.state.{KeyStateEncoderSpec,
NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec,
RangeKeyScanStateEncoderSpec, StateStore}
+import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor,
TimeMode, TimerValues, TTLConfig, ValueState}
+import org.apache.spark.sql.types.{BinaryType, LongType, NullType, StringType,
StructField, StructType}
+
+class EventTimeTimerProcessor
+ extends StatefulProcessor[String, (String, Timestamp), (String, String)] {
+ @transient var _valueState: ValueState[Long] = _
+
+ override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
+ _valueState = getHandle.getValueState("countState", Encoders.scalaLong,
TTLConfig.NONE)
+ }
+
+ override def handleInputRows(
+ key: String,
+ rows: Iterator[(String, Timestamp)],
+ timerValues: TimerValues): Iterator[(String, String)] = {
+ var maxTimestamp = 0L
+ var rowCount = 0
+ rows.foreach { case (_, timestamp) =>
+ maxTimestamp = Math.max(maxTimestamp, timestamp.getTime)
+ rowCount += 1
+ }
+
+ val count = Option(_valueState.get()).getOrElse(0L) + rowCount
+ _valueState.update(count)
+
+ // Register an event time timer
+ if (maxTimestamp > 0) {
+ getHandle.registerTimer(maxTimestamp + 5000)
+ }
+
+ Iterator((key, count.toString))
+ }
+}
+
+/**
+ * Test utility providing schema definitions and constants for
EventTimeTimerProcessor
+ * and RunningCountStatefulProcessorWithProcTimeTimer in
TransformWithStateSuite
+ */
+object TimerTestUtils {
+ case class ColumnFamilyMetadata(
Review Comment:
This is defined in several places — please define it in just one place.
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/TimerTestUtils.scala:
##########
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state.utils
+
+import java.sql.Timestamp
+
+import org.apache.spark.sql.Encoders
+import
org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.timers.TimerStateUtils
+import org.apache.spark.sql.execution.streaming.state.{KeyStateEncoderSpec,
NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec,
RangeKeyScanStateEncoderSpec, StateStore}
+import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor,
TimeMode, TimerValues, TTLConfig, ValueState}
+import org.apache.spark.sql.types.{BinaryType, LongType, NullType, StringType,
StructField, StructType}
+
+class EventTimeTimerProcessor
+ extends StatefulProcessor[String, (String, Timestamp), (String, String)] {
+ @transient var _valueState: ValueState[Long] = _
+
+ override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
+ _valueState = getHandle.getValueState("countState", Encoders.scalaLong,
TTLConfig.NONE)
+ }
+
+ override def handleInputRows(
+ key: String,
+ rows: Iterator[(String, Timestamp)],
+ timerValues: TimerValues): Iterator[(String, String)] = {
+ var maxTimestamp = 0L
+ var rowCount = 0
+ rows.foreach { case (_, timestamp) =>
+ maxTimestamp = Math.max(maxTimestamp, timestamp.getTime)
+ rowCount += 1
+ }
+
+ val count = Option(_valueState.get()).getOrElse(0L) + rowCount
+ _valueState.update(count)
+
+ // Register an event time timer
+ if (maxTimestamp > 0) {
+ getHandle.registerTimer(maxTimestamp + 5000)
+ }
+
+ Iterator((key, count.toString))
+ }
+}
+
+/**
+ * Test utility providing schema definitions and constants for
EventTimeTimerProcessor
+ * and RunningCountStatefulProcessorWithProcTimeTimer in
TransformWithStateSuite
+ */
+object TimerTestUtils {
+ case class ColumnFamilyMetadata(
+ keySchema: StructType,
+ valueSchema: StructType,
+ encoderSpec: KeyStateEncoderSpec,
+ useMultipleValuePerKey: Boolean = false)
+
+ /**
+ * Returns the grouping key schema and state value schema for a simple count
state.
+ * This is commonly used with timer tests where state tracks counts.
+ *
+ * @return A tuple of (keySchema, valueSchema)
+ */
+ def getCountStateSchemas(): (StructType, StructType) = {
+ val groupByKeySchema = StructType(Array(
+ StructField("key", StringType, nullable = true)
+ ))
+ val stateValueSchema = StructType(Array(
+ StructField("value", LongType, nullable = true)
+ ))
+ (groupByKeySchema, stateValueSchema)
+ }
+
+ /**
+ * Returns schemas for timer-related column families
+ *
+ * @param groupingKeySchema The schema for the grouping key
+ * @return A tuple of (keyToTimestampKeySchema, timestampToKeyKeySchema)
+ */
+ def getTimerKeySchemas(groupingKeySchema: StructType): (StructType,
StructType) = {
+ val keyToTimestampKeySchema = StructType(Array(
+ StructField("key", groupingKeySchema),
+ StructField("expiryTimestampMs", LongType, nullable = false)
+ ))
+ val timestampToKeyKeySchema = StructType(Array(
+ StructField("expiryTimestampMs", LongType, nullable = false),
+ StructField("key", groupingKeySchema)
+ ))
+
+ (keyToTimestampKeySchema, timestampToKeyKeySchema)
+ }
+
+ /**
+ * Returns complete metadata for timer-related column families including
encoder specs.
+ * Used for tests with timers and a count state.
+ * @param timeMode The time mode (EventTime or ProcessingTime)
+ * @return A map from column family name to ColumnFamilyMetadata
+ */
+ def getTimerConfigsForCountState(timeMode: TimeMode): Map[String,
ColumnFamilyMetadata] = {
+ val (groupByKeySchema, stateValueSchema) = getCountStateSchemas()
+ val (keyToTimestampKeySchema, timestampToKeyKeySchema) =
getTimerKeySchemas(groupByKeySchema)
+ val defaultValueSchema = StructType(Array(StructField("value", BinaryType,
nullable = true)))
+
+ val encoderSpec = NoPrefixKeyStateEncoderSpec(groupByKeySchema)
+ val keyToTimestampEncoderSpec =
PrefixKeyScanStateEncoderSpec(keyToTimestampKeySchema, 1)
+ val timestampToKeyEncoderSpec =
RangeKeyScanStateEncoderSpec(timestampToKeyKeySchema, Seq(0))
+
+ val (keyToTimestampCF, timestampToKeyCF) =
+ TimerStateUtils.getTimerStateVarNames(timeMode.toString)
+
+ val dummyValueSchema = StructType(Array(StructField("__dummy__",
NullType)))
+
+ Map(
+ StateStore.DEFAULT_COL_FAMILY_NAME -> ColumnFamilyMetadata(
+ groupByKeySchema, defaultValueSchema, encoderSpec),
+ "countState" -> ColumnFamilyMetadata(
+ groupByKeySchema, stateValueSchema, encoderSpec),
+ keyToTimestampCF -> ColumnFamilyMetadata(
+ keyToTimestampKeySchema, dummyValueSchema, keyToTimestampEncoderSpec),
+ timestampToKeyCF -> ColumnFamilyMetadata(
+ timestampToKeyKeySchema, dummyValueSchema, timestampToKeyEncoderSpec)
+ )
+ }
+
+ /**
+ * Returns select expressions for timer column families with standardized
column order.
+ *
+ * @param columnFamilyName The timer column family name
+ * @return Select expressions as a Seq of strings with order (partition_id,
key, value)
+ */
+ def getTimerSelectExpressions(columnFamilyName: String): Seq[String] = {
+ if (columnFamilyName.endsWith("_keyToTimestamp")) {
Review Comment:
This is already defined in `TimerStateUtils.KEY_TO_TIMESTAMP_CF`
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala:
##########
@@ -457,3 +458,340 @@ case class SessionUpdate(
durationMs: Long,
numEvents: Int,
expired: Boolean)
+
/**
 * Bundle describing a single state store column family in tests: its key and
 * value schemas plus how the key is encoded.
 *
 * @param keySchema schema of the state key row
 * @param valueSchema schema of the state value row
 * @param encoderSpec key encoding used by the store (no-prefix, prefix-scan,
 *                    or range-scan)
 * @param useMultipleValuePerKey whether the column family stores multiple
 *                               values per key (defaults to false)
 */
final case class ColumnFamilyMetadata(
    keySchema: StructType,
    valueSchema: StructType,
    encoderSpec: KeyStateEncoderSpec,
    useMultipleValuePerKey: Boolean = false)
+
+// Utilities for runCompositeKeyStreamingAggregationQuery.
+// TODO: move runCompositeKeyStreamingAggregationQuery into this object.
object CompositeKeyAggregationTestUtils {

  /** Returns only the (key, value) schemas for the given state format version. */
  def getSchemas(stateVersion: Int): (StructType, StructType) = {
    val cf = getSchemasWithMetadata(stateVersion)
    (cf.keySchema, cf.valueSchema)
  }

  /**
   * @param stateVersion the streaming aggregation state format version (1 or 2)
   * @return ColumnFamilyMetadata with the key/value schemas and key encoder spec
   */
  def getSchemasWithMetadata(stateVersion: Int): ColumnFamilyMetadata = {
    val keySchema = StructType(Array(
      StructField("groupKey", IntegerType, nullable = false),
      StructField("fruit", StringType, nullable = true)))

    // Aggregation result columns, identical across state format versions.
    val aggFields = Array(
      StructField("count", LongType, nullable = false),
      StructField("sum", LongType, nullable = false),
      StructField("max", IntegerType, nullable = false),
      StructField("min", IntegerType, nullable = false))

    // State version 1 repeats the grouping columns inside the value row;
    // version 2 stores only the aggregates.
    val valueSchema =
      if (stateVersion == 1) StructType(keySchema.fields ++ aggFields)
      else StructType(aggFields)

    ColumnFamilyMetadata(keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema))
  }
}
+
+// Utilities for the runDropDuplicates* queries.
+// TODO: move the runDropDuplicates* query helpers into this object.
object DropDuplicatesTestUtils {

  // All dropDuplicates variants use a flat key with a no-prefix key encoder.
  private def toMetadata(
      keySchema: StructType,
      valueSchema: StructType): ColumnFamilyMetadata =
    ColumnFamilyMetadata(keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema))

  /** Returns only the (key, value) schemas for dropDuplicates state. */
  def getDropDuplicatesSchemas(): (StructType, StructType) = {
    val cf = getDropDuplicatesSchemasWithMetadata()
    (cf.keySchema, cf.valueSchema)
  }

  /** @return metadata for dropDuplicates state keyed on (value, eventTime). */
  def getDropDuplicatesSchemasWithMetadata(): ColumnFamilyMetadata = {
    val keySchema = StructType(Array(
      StructField("value", IntegerType, nullable = false),
      StructField("eventTime", TimestampType)))
    // The store carries no payload, only a null placeholder column.
    val valueSchema = StructType(Array(
      StructField("__dummy__", NullType)))
    toMetadata(keySchema, valueSchema)
  }

  /** Returns only the (key, value) schemas for dropDuplicates on a column subset. */
  def getDropDuplicatesWithColumnSchemas(): (StructType, StructType) = {
    val cf = getDropDuplicatesWithColumnSchemasWithMetadata()
    (cf.keySchema, cf.valueSchema)
  }

  /** @return metadata for dropDuplicates state keyed on the single "col1" column. */
  def getDropDuplicatesWithColumnSchemasWithMetadata(): ColumnFamilyMetadata = {
    val keySchema = StructType(Array(
      StructField("col1", StringType, nullable = true)))
    val valueSchema = StructType(Array(
      StructField("__dummy__", NullType, nullable = true)))
    toMetadata(keySchema, valueSchema)
  }

  /** Returns only the (key, value) schemas for dropDuplicatesWithinWatermark state. */
  def getDropDuplicatesWithinWatermarkSchemas(): (StructType, StructType) = {
    val cf = getDropDuplicatesWithinWatermarkSchemasWithMetadata()
    (cf.keySchema, cf.valueSchema)
  }

  /** @return metadata for dropDuplicatesWithinWatermark; the value row stores the expiry time. */
  def getDropDuplicatesWithinWatermarkSchemasWithMetadata(): ColumnFamilyMetadata = {
    val keySchema = StructType(Array(
      StructField("_1", StringType, nullable = true)))
    val valueSchema = StructType(Array(
      StructField("expiresAtMicros", LongType, nullable = false)))
    toMetadata(keySchema, valueSchema)
  }
}
+
+/**
+ * Test utility object providing schema definitions for simple streaming
aggregation.
+ * Used by StatePartitionAllColumnFamiliesWriterSuite and
StatePartitionAllColumnFamiliesReaderSuite
+ * to eliminate code duplication.
+ */
/**
 * Test utility object providing schema definitions for simple streaming
 * aggregation. Shared by StatePartitionAllColumnFamiliesWriterSuite and
 * StatePartitionAllColumnFamiliesReaderSuite to avoid duplicated definitions.
 */
object SimpleAggregationTestUtils {

  /** Returns only the (key, value) schemas for the given state format version. */
  def getSchemas(stateVersion: Int): (StructType, StructType) = {
    val cf = getSchemasWithMetadata(stateVersion)
    (cf.keySchema, cf.valueSchema)
  }

  /**
   * @param stateVersion the streaming aggregation state format version (1 or 2)
   * @return ColumnFamilyMetadata with the key/value schemas and key encoder spec
   */
  def getSchemasWithMetadata(stateVersion: Int): ColumnFamilyMetadata = {
    val keySchema = StructType(Array(
      StructField("groupKey", IntegerType, nullable = false)))

    // Aggregation result columns, identical across state format versions.
    val aggFields = Array(
      StructField("count", LongType, nullable = false),
      StructField("sum", LongType, nullable = false),
      StructField("max", IntegerType, nullable = false),
      StructField("min", IntegerType, nullable = false))

    // State version 1 repeats the grouping column inside the value row;
    // version 2 stores only the aggregates.
    val valueSchema =
      if (stateVersion == 1) StructType(keySchema.fields ++ aggFields)
      else StructType(aggFields)

    ColumnFamilyMetadata(keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema))
  }
}
+
+/**
+ * Test utility object providing schema definitions for flatMapGroupsWithState.
+ */
/**
 * Test utility object providing schema definitions for flatMapGroupsWithState.
 */
object FlatMapGroupsWithStateTestUtils {

  /** Returns only the (key, value) schemas for the given state format version. */
  def getSchemas(stateVersion: Int): (StructType, StructType) = {
    val cf = getSchemasWithMetadata(stateVersion)
    (cf.keySchema, cf.valueSchema)
  }

  /**
   * @param stateVersion the flatMapGroupsWithState state format version (1 or 2)
   * @return ColumnFamilyMetadata with the key/value schemas and key encoder spec
   */
  def getSchemasWithMetadata(stateVersion: Int): ColumnFamilyMetadata = {
    val keySchema = StructType(Array(
      StructField("value", StringType, nullable = true)))

    // User-state columns shared by both layouts.
    val sessionFields = Array(
      StructField("numEvents", IntegerType, nullable = false),
      StructField("startTimestampMs", LongType, nullable = false),
      StructField("endTimestampMs", LongType, nullable = false))

    val valueSchema = stateVersion match {
      // Version 1: flat layout. Note the timeout column is an IntegerType here.
      case 1 =>
        StructType(sessionFields :+
          StructField("timeoutTimestamp", IntegerType, nullable = false))
      // Version 2: user state nested under "groupState", with a LongType
      // timeout column alongside it.
      case _ =>
        StructType(Array(
          StructField("groupState", StructType(sessionFields), nullable = false),
          StructField("timeoutTimestamp", LongType, nullable = false)))
    }

    ColumnFamilyMetadata(keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema))
  }
}
+
+/**
+ * Test utility object providing schema definitions for
runSessionWindowAggregationQuery
+ */
/**
 * Test utility object providing schema definitions for
 * runSessionWindowAggregationQuery.
 */
object SessionWindowTestUtils {

  /** Returns only the (key, value) schemas. */
  def getSchemas(): (StructType, StructType) = {
    val cf = getSchemasWithMetadata()
    (cf.keySchema, cf.valueSchema)
  }

  /** @return ColumnFamilyMetadata with the key/value schemas and key encoder spec */
  def getSchemasWithMetadata(): ColumnFamilyMetadata = {
    val keySchema = StructType(Array(
      StructField("sessionId", StringType, nullable = false),
      StructField("sessionStartTime", TimestampType, nullable = false)))

    // Nested struct holding the session window's start and end timestamps.
    val windowSchema = StructType(Array(
      StructField("start", TimestampType),
      StructField("end", TimestampType)))

    val valueSchema = StructType(Array(
      StructField("session_window", windowSchema, nullable = false),
      StructField("sessionId", StringType, nullable = false),
      StructField("count", LongType, nullable = false)))

    ColumnFamilyMetadata(keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema))
  }
}
+
+/**
+ * Test utility object providing schema definitions and constants for
runStreamStreamJoinQuery
+ */
+object StreamStreamJoinTestUtils {
+ // Column family names for keyToNumValues stores
+ val KEY_TO_NUM_VALUES_LEFT = "left-keyToNumValues"
+ val KEY_TO_NUM_VALUES_RIGHT = "right-keyToNumValues"
+ val KEY_TO_NUM_VALUES_ALL: Seq[String] = Seq(
+ KEY_TO_NUM_VALUES_LEFT,
+ KEY_TO_NUM_VALUES_RIGHT
+ )
+
+ // Column family names for keyWithIndexToValue stores
+ val KEY_WITH_INDEX_LEFT = "left-keyWithIndexToValue"
Review Comment:
ditto
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]