This is an automated email from the ASF dual-hosted git repository.
ashrigondekar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new df63cb73215a [SPARK-54388][SS] Introduce StatePartitionReader that
scan raw bytes for Single ColFamily
df63cb73215a is described below
commit df63cb73215aba83c586aec1ae44e19a5aeabc39
Author: zifeif2 <[email protected]>
AuthorDate: Wed Dec 3 09:40:22 2025 -0800
[SPARK-54388][SS] Introduce StatePartitionReader that scan raw bytes for
Single ColFamily
### What changes were proposed in this pull request?
Introducing a new StatePartitionReader, `StatePartitionAllColumnFamiliesReader`,
to support offline repartitioning. `StatePartitionAllColumnFamiliesReader` is
used when the user sets the internal option `_readAllColumnFamilies` to true
(only the RocksDB state store provider is supported in this mode).
We have the StateDataSource reader, which lets users read the rows of an
operator's state store through the DataFrame API, just as they would read a
normal table. However, it currently supports reading only one column family of
the state store at a time.
This change allows reading all state rows across all column families so that
they can be repartitioned in a single pass. This lets us read the entire state
store, repartition the rows, and then write the repartitioned state rows back to
cloud storage. It also improves performance, since we don't have to read each
column family separately. The state is read at the last committed batch version.
This commit covers the single-column-family case: every row is reported under
the default column family.
Since each column family can have a different schema, the returned DataFrame
represents the key and value rows as raw bytes:
- partition_key (struct; for now this is the full key row, see SPARK-54443)
- key_bytes (binary)
- value_bytes (binary)
- column_family_name (string)
### Why are the changes needed?
See above
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
See the unit tests. They not only verify the schema, but also validate that the
data is serialized to bytes correctly by comparing it against the DataFrame
returned by a normal state read.
### Was this patch authored or co-authored using generative AI tooling?
Yes. haiku, sonnet.
Closes #53104 from zifeif2/repartition-reader-single-cf.
Lead-authored-by: zifeif2 <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
Signed-off-by: Anish Shrigondekar <[email protected]>
---
.../src/main/resources/error/error-conditions.json | 5 +
.../datasources/v2/state/StateDataSource.scala | 46 +-
.../v2/state/StatePartitionReader.scala | 57 ++-
.../datasources/v2/state/utils/SchemaUtil.scala | 39 +-
.../state/OfflineStateRepartitionErrors.scala | 14 +
...tatePartitionAllColumnFamiliesReaderSuite.scala | 536 +++++++++++++++++++++
6 files changed, 688 insertions(+), 9 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-conditions.json
b/common/utils/src/main/resources/error/error-conditions.json
index 27d4552758f2..4bd4f3cfc764 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -5515,6 +5515,11 @@
"message" : [
"Unsupported offset sequence version <version>. Please make sure the
checkpoint is from a supported Spark version (Spark 4.0+)."
]
+ },
+ "UNSUPPORTED_PROVIDER" : {
+ "message" : [
+ "<provider> is not supported"
+ ]
}
},
"sqlState" : "55019"
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
index c97a70eb3c8c..ee46480a7a12 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
@@ -41,7 +41,8 @@ import
org.apache.spark.sql.execution.streaming.operators.stateful.transformwith
import
org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.timers.TimerStateUtils
import
org.apache.spark.sql.execution.streaming.runtime.StreamingCheckpointConstants.DIR_NAME_STATE
import
org.apache.spark.sql.execution.streaming.runtime.StreamingQueryCheckpointMetadata
-import
org.apache.spark.sql.execution.streaming.state.{InMemoryStateSchemaProvider,
KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec,
PrefixKeyScanStateEncoderSpec, StateSchemaCompatibilityChecker,
StateSchemaMetadata, StateSchemaProvider, StateStore,
StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProviderId}
+import
org.apache.spark.sql.execution.streaming.state.{InMemoryStateSchemaProvider,
KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec,
PrefixKeyScanStateEncoderSpec, RocksDBStateStoreProvider,
StateSchemaCompatibilityChecker, StateSchemaMetadata, StateSchemaProvider,
StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId,
StateStoreProviderId}
+import
org.apache.spark.sql.execution.streaming.state.OfflineStateRepartitionErrors
import org.apache.spark.sql.execution.streaming.utils.StreamingUtils
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.streaming.TimeMode
@@ -66,6 +67,14 @@ class StateDataSource extends TableProvider with
DataSourceRegister with Logging
val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf,
StateSourceOptions.apply(session, hadoopConf, properties))
val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation,
sourceOptions.batchId)
+ // We only support RocksDB because the repartition work that this option
+ // is built for only supports RocksDB
+ if (sourceOptions.internalOnlyReadAllColumnFamilies
+ && stateConf.providerClass !=
classOf[RocksDBStateStoreProvider].getName) {
+ throw OfflineStateRepartitionErrors.unsupportedStateStoreProviderError(
+ sourceOptions.resolvedCpLocation,
+ stateConf.providerClass)
+ }
val stateStoreReaderInfo: StateStoreReaderInfo =
getStoreMetadataAndRunChecks(
sourceOptions)
@@ -372,6 +381,7 @@ case class StateSourceOptions(
stateVarName: Option[String],
readRegisteredTimers: Boolean,
flattenCollectionTypes: Boolean,
+ internalOnlyReadAllColumnFamilies: Boolean = false,
startOperatorStateUniqueIds: Option[Array[Array[String]]] = None,
endOperatorStateUniqueIds: Option[Array[Array[String]]] = None) {
def stateCheckpointLocation: Path = new Path(resolvedCpLocation,
DIR_NAME_STATE)
@@ -379,7 +389,7 @@ case class StateSourceOptions(
override def toString: String = {
var desc = s"StateSourceOptions(checkpointLocation=$resolvedCpLocation,
batchId=$batchId, " +
s"operatorId=$operatorId, storeName=$storeName, joinSide=$joinSide, " +
- s"stateVarName=${stateVarName.getOrElse("None")}, +" +
+ s"stateVarName=${stateVarName.getOrElse("None")}, " +
s"flattenCollectionTypes=$flattenCollectionTypes"
if (fromSnapshotOptions.isDefined) {
desc += s",
snapshotStartBatchId=${fromSnapshotOptions.get.snapshotStartBatchId}"
@@ -393,7 +403,7 @@ case class StateSourceOptions(
}
}
-object StateSourceOptions extends DataSourceOptions {
+object StateSourceOptions extends DataSourceOptions with Logging{
val PATH = newOption("path")
val BATCH_ID = newOption("batchId")
val OPERATOR_ID = newOption("operatorId")
@@ -407,6 +417,7 @@ object StateSourceOptions extends DataSourceOptions {
val STATE_VAR_NAME = newOption("stateVarName")
val READ_REGISTERED_TIMERS = newOption("readRegisteredTimers")
val FLATTEN_COLLECTION_TYPES = newOption("flattenCollectionTypes")
+ val INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES =
newOption("_readAllColumnFamilies")
object JoinSideValues extends Enumeration {
type JoinSideValues = Value
@@ -492,6 +503,33 @@ object StateSourceOptions extends DataSourceOptions {
val readChangeFeed =
Option(options.get(READ_CHANGE_FEED)).exists(_.toBoolean)
+ val internalOnlyReadAllColumnFamilies = try {
+
Option(options.get(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES)).exists(_.toBoolean)
+ } catch {
+ case _: IllegalArgumentException =>
+ throw
StateDataSourceErrors.invalidOptionValue(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES,
+ "Boolean value is expected")
+ }
+
+ // This config should only be used by internal callers e.g. repartitioning
+ if (internalOnlyReadAllColumnFamilies) {
+ logWarning("StateSourceOptions option
INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES is enabled. " +
+ "This config should only be used for internal callers e.g.
repartitioning")
+ if (stateVarName.isDefined) {
+ throw StateDataSourceErrors.conflictOptions(
+ Seq(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, STATE_VAR_NAME))
+ }
+ // Use storeName rather than joinSide to identify the specific join store
+ if (joinSide != JoinSideValues.none) {
+ throw StateDataSourceErrors.conflictOptions(
+ Seq(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, JOIN_SIDE))
+ }
+ if (readChangeFeed) {
+ throw StateDataSourceErrors.conflictOptions(
+ Seq(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, READ_CHANGE_FEED))
+ }
+ }
+
val changeStartBatchId =
Option(options.get(CHANGE_START_BATCH_ID)).map(_.toLong)
var changeEndBatchId =
Option(options.get(CHANGE_END_BATCH_ID)).map(_.toLong)
@@ -615,7 +653,7 @@ object StateSourceOptions extends DataSourceOptions {
StateSourceOptions(
resolvedCpLocation, batchId.get, operatorId, storeName, joinSide,
readChangeFeed, fromSnapshotOptions, readChangeFeedOptions,
- stateVarName, readRegisteredTimers, flattenCollectionTypes,
+ stateVarName, readRegisteredTimers, flattenCollectionTypes,
internalOnlyReadAllColumnFamilies,
startOperatorStateUniqueIds, endOperatorStateUniqueIds)
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
index 619e374c00de..9fc3c081173f 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
@@ -49,7 +49,10 @@ class StatePartitionReaderFactory(
override def createReader(partition: InputPartition):
PartitionReader[InternalRow] = {
val stateStoreInputPartition =
partition.asInstanceOf[StateStoreInputPartition]
- if (stateStoreInputPartition.sourceOptions.readChangeFeed) {
+ if
(stateStoreInputPartition.sourceOptions.internalOnlyReadAllColumnFamilies) {
+ new StatePartitionAllColumnFamiliesReader(storeConf, hadoopConf,
+ stateStoreInputPartition, schema, keyStateEncoderSpec,
stateStoreColFamilySchemaOpt)
+ } else if (stateStoreInputPartition.sourceOptions.readChangeFeed) {
new StateStoreChangeDataPartitionReader(storeConf, hadoopConf,
stateStoreInputPartition, schema, keyStateEncoderSpec,
stateVariableInfoOpt,
stateStoreColFamilySchemaOpt, stateSchemaProviderOpt, joinColFamilyOpt)
@@ -81,16 +84,22 @@ abstract class StatePartitionReaderBase(
private val schemaForValueRow: StructType =
StructType(Array(StructField("__dummy__", NullType)))
- protected val keySchema = {
+ protected val keySchema : StructType = {
if (SchemaUtil.checkVariableType(stateVariableInfoOpt,
StateVariableType.MapState)) {
SchemaUtil.getCompositeKeySchema(schema, partition.sourceOptions)
+ } else if (partition.sourceOptions.internalOnlyReadAllColumnFamilies) {
+ require(stateStoreColFamilySchemaOpt.isDefined)
+ stateStoreColFamilySchemaOpt.map(_.keySchema).get
} else {
SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
}
}
- protected val valueSchema = if (stateVariableInfoOpt.isDefined) {
+ protected val valueSchema : StructType = if (stateVariableInfoOpt.isDefined)
{
schemaForValueRow
+ } else if (partition.sourceOptions.internalOnlyReadAllColumnFamilies) {
+ require(stateStoreColFamilySchemaOpt.isDefined)
+ stateStoreColFamilySchemaOpt.map(_.valueSchema).get
} else {
SchemaUtil.getSchemaAsDataType(
schema, "value").asInstanceOf[StructType]
@@ -237,6 +246,48 @@ class StatePartitionReader(
}
}
+/**
+ * An implementation of [[StatePartitionReaderBase]] for reading all column
families
+ * in binary format. This reader returns raw key and value bytes along with
column family names.
+ * We are returning key/value bytes because each column family can have
different schema
+ * It will also return the partition key
+ */
+class StatePartitionAllColumnFamiliesReader(
+ storeConf: StateStoreConf,
+ hadoopConf: SerializableConfiguration,
+ partition: StateStoreInputPartition,
+ schema: StructType,
+ keyStateEncoderSpec: KeyStateEncoderSpec,
+ stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema])
+ extends StatePartitionReaderBase(
+ storeConf,
+ hadoopConf, partition, schema,
+ keyStateEncoderSpec, None, stateStoreColFamilySchemaOpt, None, None) {
+
+ private lazy val store: ReadStateStore = {
+ assert(getStartStoreUniqueId == getEndStoreUniqueId,
+ "Start and end store unique IDs must be the same when reading all column
families")
+ provider.getReadStore(
+ partition.sourceOptions.batchId + 1,
+ getStartStoreUniqueId
+ )
+ }
+
+ override lazy val iter: Iterator[InternalRow] = {
+ store
+ .iterator()
+ .map { pair =>
+ SchemaUtil.unifyStateRowPairAsRawBytes(
+ (pair.key, pair.value), StateStore.DEFAULT_COL_FAMILY_NAME)
+ }
+ }
+
+ override def close(): Unit = {
+ store.release()
+ super.close()
+ }
+}
+
/**
* An implementation of [[StatePartitionReaderBase]] for the readChangeFeed
mode of State Data
* Source. It reads the change of state over batches of a particular partition.
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
index 52df016791d4..44d83fc99b57 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
@@ -28,7 +28,8 @@ import
org.apache.spark.sql.execution.datasources.v2.state.{StateDataSourceError
import
org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateVariableType,
TransformWithStateVariableInfo}
import
org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.StateVariableType._
import org.apache.spark.sql.execution.streaming.state.{ReadStateStore,
StateStoreColFamilySchema, UnsafeRowPair}
-import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, LongType,
MapType, StringType, StructType}
+import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType,
IntegerType, LongType, MapType, StringType, StructType}
+import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.ArrayImplicits._
object SchemaUtil {
@@ -60,6 +61,14 @@ object SchemaUtil {
.add("key", keySchema)
.add("value", valueSchema)
.add("partition_id", IntegerType)
+ } else if (sourceOptions.internalOnlyReadAllColumnFamilies) {
+ new StructType()
+ // todo [SPARK-54443]: change keySchema to a more specific type after
we
+ // can extract partition key from keySchema
+ .add("partition_key", keySchema)
+ .add("key_bytes", BinaryType)
+ .add("value_bytes", BinaryType)
+ .add("column_family_name", StringType)
} else {
new StructType()
.add("key", keySchema)
@@ -76,6 +85,26 @@ object SchemaUtil {
row
}
+ /**
+ * Returns an InternalRow representing
+ * 1. partitionKey
+ * 2. key in bytes
+ * 3. value in bytes
+ * 4. column family name
+ */
+ def unifyStateRowPairAsRawBytes(
+ pair: (UnsafeRow, UnsafeRow),
+ colFamilyName: String): InternalRow = {
+ val row = new GenericInternalRow(4)
+ // todo [SPARK-54443]: change keySchema to more specific type after we
+ // can extract partition key from keySchema
+ row.update(0, pair._1)
+ row.update(1, pair._1.getBytes)
+ row.update(2, pair._2.getBytes)
+ row.update(3, UTF8String.fromString(colFamilyName))
+ row
+ }
+
def unifyStateRowPairWithMultipleValues(
pair: (UnsafeRow, GenericArrayData),
partition: Int): InternalRow = {
@@ -231,7 +260,11 @@ object SchemaUtil {
"user_map_key" -> classOf[StructType],
"user_map_value" -> classOf[StructType],
"expiration_timestamp_ms" -> classOf[LongType],
- "partition_id" -> classOf[IntegerType])
+ "partition_id" -> classOf[IntegerType],
+ "partition_key" -> classOf[StructType],
+ "key_bytes" -> classOf[BinaryType],
+ "value_bytes" -> classOf[BinaryType],
+ "column_family_name" -> classOf[StringType])
val expectedFieldNames = if (transformWithStateVariableInfoOpt.isDefined) {
val stateVarInfo = transformWithStateVariableInfoOpt.get
@@ -272,6 +305,8 @@ object SchemaUtil {
}
} else if (sourceOptions.readChangeFeed) {
Seq("batch_id", "change_type", "key", "value", "partition_id")
+ } else if (sourceOptions.internalOnlyReadAllColumnFamilies) {
+ Seq("partition_key", "key_bytes", "value_bytes", "column_family_name")
} else {
Seq("key", "value", "partition_id")
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala
index 95b273826877..0e9b8ad8a63b 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala
@@ -85,6 +85,12 @@ object OfflineStateRepartitionErrors {
version: Int): StateRepartitionInvalidCheckpointError = {
new StateRepartitionUnsupportedOffsetSeqVersionError(checkpointLocation,
version)
}
+
+ def unsupportedStateStoreProviderError(
+ checkpointLocation: String,
+ providerClass: String): StateRepartitionInvalidCheckpointError = {
+ new StateRepartitionUnsupportedProviderError(checkpointLocation,
providerClass)
+ }
}
/**
@@ -201,3 +207,11 @@ class StateRepartitionUnsupportedOffsetSeqVersionError(
checkpointLocation,
subClass = "UNSUPPORTED_OFFSET_SEQ_VERSION",
messageParameters = Map("version" -> version.toString))
+
+class StateRepartitionUnsupportedProviderError(
+ checkpointLocation: String,
+ provider: String)
+ extends StateRepartitionInvalidCheckpointError(
+ checkpointLocation,
+ subClass = "UNSUPPORTED_PROVIDER",
+ messageParameters = Map("provider" -> provider))
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala
new file mode 100644
index 000000000000..c4b59b149b96
--- /dev/null
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala
@@ -0,0 +1,536 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state
+
+import java.nio.ByteOrder
+import java.util.Arrays
+
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
+import
org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider,
RocksDBStateStoreProvider, StateRepartitionUnsupportedProviderError, StateStore}
+import org.apache.spark.sql.functions.{count, sum}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.{BooleanType, IntegerType, LongType,
NullType, StructField, StructType, TimestampType}
+
+/**
+ * Note: This extends StateDataSourceTestBase to access
+ * helper methods like runDropDuplicatesQuery without inheriting all
predefined tests.
+ */
+class StatePartitionAllColumnFamiliesReaderSuite extends
StateDataSourceTestBase {
+
+ import testImplicits._
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
+ classOf[RocksDBStateStoreProvider].getName)
+ }
+
+ private def getNormalReadDf(
+ checkpointDir: String,
+ storeName: Option[String] = Option.empty[String]): DataFrame = {
+ spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, checkpointDir)
+ .option(StateSourceOptions.STORE_NAME, storeName.orNull)
+ .load()
+ .selectExpr("partition_id", "key", "value")
+ }
+
+ private def getBytesReadDf(
+ checkpointDir: String,
+ storeName: Option[String] = Option.empty[String]): DataFrame = {
+ spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, checkpointDir)
+ .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES,
"true")
+ .option(StateSourceOptions.STORE_NAME, storeName.orNull)
+ .load()
+ }
+
+ /**
+ * Validates the schema and column families of the bytes read DataFrame.
+ */
+ private def validateBytesReadDfSchema(df: DataFrame): Unit = {
+ // Verify schema
+ val schema = df.schema
+ assert(schema.fieldNames === Array(
+ "partition_key", "key_bytes", "value_bytes", "column_family_name"))
+ assert(schema("partition_key").dataType.typeName === "struct")
+ assert(schema("key_bytes").dataType.typeName === "binary")
+ assert(schema("value_bytes").dataType.typeName === "binary")
+ assert(schema("column_family_name").dataType.typeName === "string")
+ }
+
+ /**
+ * Compares normal read data with bytes read data for a specific column
family.
+ * Converts normal rows to bytes then compares with bytes read.
+ */
+ private def compareNormalAndBytesData(
+ normalDf: Array[Row],
+ bytesDf: Array[Row],
+ columnFamily: String,
+ keySchema: StructType,
+ valueSchema: StructType): Unit = {
+
+ // Filter bytes data for the specified column family and extract raw bytes
directly
+ val filteredBytesData = bytesDf.filter { row =>
+ row.getString(3) == columnFamily
+ }
+
+ // Verify same number of rows
+ assert(filteredBytesData.length == normalDf.length,
+ s"Row count mismatch for column family '$columnFamily': " +
+ s"normal read has ${normalDf.length} rows, " +
+ s"bytes read has ${filteredBytesData.length} rows")
+
+ // Create projections to convert Row to UnsafeRow bytes
+ val keyProjection = UnsafeProjection.create(keySchema)
+ val valueProjection = UnsafeProjection.create(valueSchema)
+
+ // Create converters to convert external Row types to internal Catalyst
types
+ val keyConverter =
CatalystTypeConverters.createToCatalystConverter(keySchema)
+ val valueConverter =
CatalystTypeConverters.createToCatalystConverter(valueSchema)
+
+ // Convert normal data to bytes
+ val normalAsBytes = normalDf.toSeq.map { row =>
+ val key = row.getStruct(1)
+ val value = if (row.isNullAt(2)) null else row.getStruct(2)
+
+ // Convert key to InternalRow, then to UnsafeRow, then get bytes
+ val keyInternalRow = keyConverter(key).asInstanceOf[InternalRow]
+ val keyUnsafeRow = keyProjection(keyInternalRow)
+ // IMPORTANT: Must clone the bytes array since getBytes() returns a
reference
+ // that may be overwritten by subsequent UnsafeRow operations
+ val keyBytes = keyUnsafeRow.getBytes.clone()
+
+ // Convert value to bytes
+ val valueBytes = if (value == null) {
+ Array.empty[Byte]
+ } else {
+ val valueInternalRow = valueConverter(value).asInstanceOf[InternalRow]
+ val valueUnsafeRow = valueProjection(valueInternalRow)
+ // IMPORTANT: Must clone the bytes array since getBytes() returns a
reference
+ // that may be overwritten by subsequent UnsafeRow operations
+ valueUnsafeRow.getBytes.clone()
+ }
+
+ (keyBytes, valueBytes)
+ }
+
+ // Extract raw bytes from bytes read data (no
deserialization/reserialization)
+ val bytesAsBytes = filteredBytesData.map { row =>
+ val keyBytes = row.getAs[Array[Byte]](1)
+ val valueBytes = row.getAs[Array[Byte]](2)
+ (keyBytes, valueBytes)
+ }
+
+ // Sort both for comparison (since Set equality doesn't work well with
byte arrays)
+ val normalSorted = normalAsBytes.sortBy(x => (x._1.mkString(","),
x._2.mkString(",")))
+ val bytesSorted = bytesAsBytes.sortBy(x => (x._1.mkString(","),
x._2.mkString(",")))
+
+ assert(normalSorted.length == bytesSorted.length,
+ s"Size mismatch: normal has ${normalSorted.length}, bytes has
${bytesSorted.length}")
+
+ // Compare each pair
+ normalSorted.zip(bytesSorted).zipWithIndex.foreach {
+ case (((normalKey, normalValue), (bytesKey, bytesValue)), idx) =>
+ assert(Arrays.equals(normalKey, bytesKey),
+ s"Key mismatch at index $idx:\n" +
+ s" Normal: ${normalKey.mkString("[", ",", "]")}\n" +
+ s" Bytes: ${bytesKey.mkString("[", ",", "]")}")
+ assert(Arrays.equals(normalValue, bytesValue),
+ s"Value mismatch at index $idx:\n" +
+ s" Normal: ${normalValue.mkString("[", ",", "]")}\n" +
+ s" Bytes: ${bytesValue.mkString("[", ",", "]")}")
+ }
+ }
+
+ // Run all tests with both changelog checkpointing enabled and disabled
+ Seq(true, false).foreach { changelogCheckpointingEnabled =>
+ val testSuffix = if (changelogCheckpointingEnabled) {
+ "with changelog checkpointing"
+ } else {
+ "without changelog checkpointing"
+ }
+
+ def testWithChangelogConfig(testName: String)(testFun: => Unit): Unit = {
+ test(s"$testName ($testSuffix)") {
+ withSQLConf(
+
"spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled" ->
+ changelogCheckpointingEnabled.toString) {
+ testFun
+ }
+ }
+ }
+
+ testWithChangelogConfig("SPARK-54388: simple aggregation state ver 1") {
+ withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key ->
"1") {
+ withTempDir { tempDir =>
+ runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath)
+
+ val keySchema = StructType(Array(StructField("groupKey", IntegerType,
nullable = false)))
+ // State version 1 includes key columns in the value
+ val valueSchema = StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false),
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+
+ val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect()
+ val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+
+ validateBytesReadDfSchema(bytesDf)
+ compareNormalAndBytesData(normalData, bytesDf.collect(), "default",
keySchema, valueSchema)
+ }
+ }
+ }
+
+ testWithChangelogConfig("SPARK-54388: simple aggregation state ver 2") {
+ withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key ->
"2") {
+ withTempDir { tempDir =>
+ runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath)
+
+ val keySchema = StructType(Array(StructField("groupKey", IntegerType,
nullable = false)))
+ val valueSchema = StructType(Array(
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+
+ val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect()
+ val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+
+ validateBytesReadDfSchema(bytesDf)
+ compareNormalAndBytesData(normalData, bytesDf.collect(), "default",
keySchema, valueSchema)
+ }
+ }
+ }
+
+ testWithChangelogConfig("SPARK-54388: composite key aggregation state ver
1") {
+ withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key ->
"1") {
+ withTempDir { tempDir =>
+ runCompositeKeyStreamingAggregationQuery(tempDir.getAbsolutePath)
+
+ val keySchema = StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false),
+ StructField("fruit", org.apache.spark.sql.types.StringType, nullable
= true)
+ ))
+ // State version 1 includes key columns in the value
+ val valueSchema = StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false),
+ StructField("fruit", org.apache.spark.sql.types.StringType, nullable
= true),
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+
+ val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect()
+ val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+
+ validateBytesReadDfSchema(bytesDf)
+ compareNormalAndBytesData(normalData, bytesDf.collect(), "default",
keySchema, valueSchema)
+ }
+ }
+ }
+
+ testWithChangelogConfig("SPARK-54388: composite key aggregation state ver
2") {
+ withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key ->
"2") {
+ withTempDir { tempDir =>
+ runCompositeKeyStreamingAggregationQuery(tempDir.getAbsolutePath)
+
+ val keySchema = StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false),
+ StructField("fruit", org.apache.spark.sql.types.StringType, nullable
= true)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+
+ val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect()
+ val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+
+ validateBytesReadDfSchema(bytesDf)
+ compareNormalAndBytesData(normalData, bytesDf.collect(), "default",
keySchema, valueSchema)
+ }
+ }
+ }
+
+ testWithChangelogConfig("SPARK-54388: dropDuplicates validation") {
+ withTempDir { tempDir =>
+ runDropDuplicatesQuery(tempDir.getAbsolutePath)
+
+ val keySchema = StructType(Array(
+ StructField("value", IntegerType, nullable = false),
+ StructField("eventTime", org.apache.spark.sql.types.TimestampType)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("__dummy__", NullType, nullable = true)
+ ))
+
+ val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect()
+ val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+
+ validateBytesReadDfSchema(bytesDf)
+ compareNormalAndBytesData(normalData, bytesDf.collect(), "default",
keySchema, valueSchema)
+ }
+ }
+
+ testWithChangelogConfig("SPARK-54388: dropDuplicates with column
specified") {
+ withTempDir { tempDir =>
+ runDropDuplicatesQueryWithColumnSpecified(tempDir.getAbsolutePath)
+
+ val keySchema = StructType(Array(
+ StructField("col1", org.apache.spark.sql.types.StringType, nullable
= true)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("__dummy__", NullType, nullable = true)
+ ))
+
+ val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect()
+ val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+
+ validateBytesReadDfSchema(bytesDf)
+ compareNormalAndBytesData(normalData, bytesDf.collect(), "default",
keySchema, valueSchema)
+ }
+ }
+
+ testWithChangelogConfig("SPARK-54388: dropDuplicatesWithinWatermark") {
+ withTempDir { tempDir =>
+ runDropDuplicatesWithinWatermarkQuery(tempDir.getAbsolutePath)
+
+ val keySchema = StructType(Array(
+ StructField("_1", org.apache.spark.sql.types.StringType, nullable =
true)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("expiresAtMicros", LongType, nullable = false)
+ ))
+
+ val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect()
+ val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+
+ validateBytesReadDfSchema(bytesDf)
+ compareNormalAndBytesData(normalData, bytesDf.collect(), "default",
keySchema, valueSchema)
+ }
+ }
+
+    testWithChangelogConfig("SPARK-54388: session window aggregation") {
+      withTempDir { tempDir =>
+        val checkpointPath = tempDir.getAbsolutePath
+        runSessionWindowAggregationQuery(checkpointPath)
+
+        // Nested struct stored under `session_window`; both bounds default to nullable.
+        val sessionWindowType = new StructType()
+          .add("start", TimestampType)
+          .add("end", TimestampType)
+        val expectedKeySchema = new StructType()
+          .add("sessionId", org.apache.spark.sql.types.StringType, nullable = false)
+          .add("sessionStartTime", TimestampType, nullable = false)
+        val expectedValueSchema = new StructType()
+          .add("session_window", sessionWindowType, nullable = false)
+          .add("sessionId", org.apache.spark.sql.types.StringType, nullable = false)
+          .add("count", LongType, nullable = false)
+
+        // Decode via the normal reader first, then check the raw-bytes reader against it.
+        val rowsFromNormalRead = getNormalReadDf(checkpointPath).collect()
+        val rawBytesDf = getBytesReadDf(checkpointPath)
+        validateBytesReadDfSchema(rawBytesDf)
+        compareNormalAndBytesData(
+          rowsFromNormalRead,
+          rawBytesDf.collect(),
+          "default",
+          expectedKeySchema,
+          expectedValueSchema)
+      }
+    }
+
+    testWithChangelogConfig("SPARK-54388: flatMapGroupsWithState, state ver 1") {
+      // Skip this test on big endian platforms. The check runs once, before any query is
+      // started; the duplicate check previously repeated inside withTempDir was redundant.
+      assume(java.nio.ByteOrder.nativeOrder().equals(java.nio.ByteOrder.LITTLE_ENDIAN))
+      withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> "1") {
+        withTempDir { tempDir =>
+          runFlatMapGroupsWithStateQuery(tempDir.getAbsolutePath)
+
+          val keySchema = StructType(Array(
+            StructField("value", org.apache.spark.sql.types.StringType, nullable = true)
+          ))
+          // Expected version-1 value schema: the state fields are flat, and
+          // `timeoutTimestamp` is an IntegerType column (a LongType in version 2).
+          val valueSchema = StructType(Array(
+            StructField("numEvents", IntegerType, nullable = false),
+            StructField("startTimestampMs", LongType, nullable = false),
+            StructField("endTimestampMs", LongType, nullable = false),
+            StructField("timeoutTimestamp", IntegerType, nullable = false)
+          ))
+
+          val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect()
+          val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+
+          validateBytesReadDfSchema(bytesDf)
+          compareNormalAndBytesData(
+            normalData, bytesDf.collect(), "default", keySchema, valueSchema)
+        }
+      }
+    }
+
+    testWithChangelogConfig("SPARK-54388: flatMapGroupsWithState, state ver 2") {
+      withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> "2") {
+        withTempDir { tempDir =>
+          val checkpointPath = tempDir.getAbsolutePath
+          runFlatMapGroupsWithStateQuery(checkpointPath)
+
+          // Expected version-2 value schema: the user state is nested in a `groupState`
+          // struct, with a LongType `timeoutTimestamp` stored alongside it.
+          val groupStateType = new StructType()
+            .add("numEvents", IntegerType, nullable = false)
+            .add("startTimestampMs", LongType, nullable = false)
+            .add("endTimestampMs", LongType, nullable = false)
+          val expectedKeySchema = new StructType()
+            .add("value", org.apache.spark.sql.types.StringType, nullable = true)
+          val expectedValueSchema = new StructType()
+            .add("groupState", groupStateType, nullable = false)
+            .add("timeoutTimestamp", LongType, nullable = false)
+
+          // Decode via the normal reader first, then check the raw-bytes reader against it.
+          val rowsFromNormalRead = getNormalReadDf(checkpointPath).collect()
+          val rawBytesDf = getBytesReadDf(checkpointPath)
+          validateBytesReadDfSchema(rawBytesDf)
+          compareNormalAndBytesData(
+            rowsFromNormalRead,
+            rawBytesDf.collect(),
+            "default",
+            expectedKeySchema,
+            expectedValueSchema)
+        }
+      }
+    }
+
+    /**
+     * Runs a stream-stream join under the given join state format version, then checks
+     * that the raw-bytes reader agrees with the normal reader for all four join state
+     * stores (keyToNumValues and keyWithIndexToValue, for both the left and right side).
+     *
+     * @param stateVersion value for SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION; version 2
+     *                     adds a `matched` column to the keyWithIndexToValue value schema.
+     */
+    def testStreamStreamJoin(stateVersion: Int): Unit = {
+      withSQLConf(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION.key -> stateVersion.toString) {
+        withTempDir { tempDir =>
+          val checkpointPath = tempDir.getAbsolutePath
+          runStreamStreamJoinQuery(checkpointPath)
+
+          // Schemas are identical for the left and right stores, so build them once
+          // instead of once per loop iteration.
+          val keyToNumValuesKeySchema = StructType(Array(
+            StructField("key", IntegerType)
+          ))
+          val keyToNumValuesValueSchema = StructType(Array(
+            StructField("value", LongType)
+          ))
+
+          val keyWithIndexToValueKeySchema = StructType(Array(
+            StructField("key", IntegerType, nullable = false),
+            StructField("index", LongType)
+          ))
+          val keyWithIndexToValueValueSchema = if (stateVersion == 2) {
+            StructType(Array(
+              StructField("value", IntegerType, nullable = false),
+              StructField("time", TimestampType, nullable = false),
+              StructField("matched", BooleanType)
+            ))
+          } else {
+            StructType(Array(
+              StructField("value", IntegerType, nullable = false),
+              StructField("time", TimestampType, nullable = false)
+            ))
+          }
+
+          // Reads one join state store with both readers, validates the raw-bytes schema,
+          // and compares the decoded rows against the raw-bytes rows.
+          def verifyStore(
+              storeName: String,
+              keySchema: StructType,
+              valueSchema: StructType): Unit = {
+            val normalDf = getNormalReadDf(checkpointPath, Some(storeName))
+            val bytesDf = getBytesReadDf(checkpointPath, Some(storeName))
+            validateBytesReadDfSchema(bytesDf)
+            compareNormalAndBytesData(
+              normalDf.collect(),
+              bytesDf.collect(),
+              StateStore.DEFAULT_COL_FAMILY_NAME,
+              keySchema,
+              valueSchema)
+          }
+
+          Seq("right-keyToNumValues", "left-keyToNumValues").foreach { storeName =>
+            verifyStore(storeName, keyToNumValuesKeySchema, keyToNumValuesValueSchema)
+          }
+          Seq("right-keyWithIndexToValue", "left-keyWithIndexToValue").foreach { storeName =>
+            verifyStore(
+              storeName, keyWithIndexToValueKeySchema, keyWithIndexToValueValueSchema)
+          }
+        }
+      }
+    }
+
+    // Register one test per join state format version, in ascending order.
+    Seq(1, 2).foreach { joinStateVersion =>
+      testWithChangelogConfig(s"stream-stream join, state ver $joinStateVersion") {
+        testStreamStreamJoin(joinStateVersion)
+      }
+    }
+ } // End of foreach loop for changelog checkpointing dimension
+
+  test("internalOnlyReadAllColumnFamilies should fail with HDFS-backed state store") {
+    withTempDir { tempDir =>
+      withSQLConf(
+        SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[HDFSBackedStateStoreProvider].getName,
+        SQLConf.SHUFFLE_PARTITIONS.key -> "2") {
+
+        // Produce a checkpoint written by the HDFS-backed provider: one batch, one row.
+        val inputData = MemoryStream[Int]
+        val aggregated = inputData.toDF()
+          .selectExpr("value", "value % 10 AS groupKey")
+          .groupBy($"groupKey")
+          .agg(
+            count("*").as("cnt"),
+            sum("value").as("sum")
+          )
+          .as[(Int, Long, Long)]
+
+        testStream(aggregated, OutputMode.Update)(
+          StartStream(checkpointLocation = tempDir.getAbsolutePath),
+          AddData(inputData, 0),
+          CheckLastBatch((0, 1, 0)),
+          StopStream
+        )
+
+        // With matchPVals = true the expected parameter values are regexes. Quote the
+        // temp-dir path so any regex metacharacters in it (e.g. '\' on Windows, '+')
+        // are matched literally instead of breaking or loosening the pattern.
+        checkError(
+          exception = intercept[StateRepartitionUnsupportedProviderError] {
+            spark.read
+              .format("statestore")
+              .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+              .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, "true")
+              .load()
+              .collect()
+          },
+          condition = "STATE_REPARTITION_INVALID_CHECKPOINT.UNSUPPORTED_PROVIDER",
+          parameters = Map(
+            "checkpointLocation" ->
+              s".*${java.util.regex.Pattern.quote(tempDir.getAbsolutePath)}",
+            "provider" -> classOf[HDFSBackedStateStoreProvider].getName
+          ),
+          matchPVals = true
+        )
+      }
+    }
+  }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]