micheal-o commented on code in PR #53104:
URL: https://github.com/apache/spark/pull/53104#discussion_r2567181951
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala:
##########
@@ -78,19 +83,39 @@ abstract class StatePartitionReaderBase(
extends PartitionReader[InternalRow] with Logging {
// Used primarily as a placeholder for the value schema in the context of
// state variables used within the transformWithState operator.
- private val schemaForValueRow: StructType =
+ // Also used as a placeholder for both key and value schema for
+ // StatePartitionAllColumnFamiliesReader
+ private val placeholderSchema: StructType =
StructType(Array(StructField("__dummy__", NullType)))
+ private val colFamilyToSchema : mutable.HashMap[String,
StateStoreColFamilySchema] = {
+ val stateStoreId = StateStoreId(
+ partition.sourceOptions.stateCheckpointLocation.toString,
+ partition.sourceOptions.operatorId,
+ StateStore.PARTITION_ID_TO_CHECK_SCHEMA,
+ partition.sourceOptions.storeName)
+ val stateStoreProviderId = StateStoreProviderId(stateStoreId,
partition.queryId)
+ val manager = new StateSchemaCompatibilityChecker(stateStoreProviderId,
hadoopConf.value)
Review Comment:
Why duplicate the code from `getStoreMetadataAndRunChecks` here? We already
read the schema file in StateDataSource (i.e. in the `inferSchema` and
`getStoreMetadataAndRunChecks` methods), and we pass it in here via the
`stateStoreColFamilySchemaOpt` param, so we can rely on
`stateStoreColFamilySchemaOpt` here.
Also, since this PR only covers a single column family (CF), let's focus on
the single-CF case for now. In a subsequent PR, we can pass in the schemas for all CFs.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala:
##########
@@ -201,3 +208,10 @@ class StateRepartitionUnsupportedOffsetSeqVersionError(
checkpointLocation,
subClass = "UNSUPPORTED_OFFSET_SEQ_VERSION",
messageParameters = Map("version" -> version.toString))
+
+class StateRepartitionUnsupportedProviderError(
+ checkpointLocation: String,
+ provider: String) extends StateRepartitionInvalidCheckpointError(
Review Comment:
Please reformat this to match how the other error classes are formatted, and
make sure you run scalafmt before submitting for review
([doc](https://spark.apache.org/developer-tools.html#:~:text=the%20style%20guide.-,Formatting%20code,-To%20format%20Scala))
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala:
##########
@@ -0,0 +1,523 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state
+
+import java.nio.ByteOrder
+import java.util.Arrays
+
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
+import
org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider,
RocksDBStateStoreProvider, StateRepartitionUnsupportedProviderError, StateStore}
+import org.apache.spark.sql.functions.{count, sum}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.{BooleanType, IntegerType, LongType,
NullType, StructField, StructType, TimestampType}
+
+/**
+ * Note: This extends StateDataSourceTestBase to access
+ * helper methods like runDropDuplicatesQuery without inheriting all
predefined tests.
+ */
+class StatePartitionAllColumnFamiliesReaderSuite extends
StateDataSourceTestBase {
+
+ import testImplicits._
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
+ classOf[RocksDBStateStoreProvider].getName)
+ }
+
+ private def getNormalReadDf(
+ checkpointDir: String,
+ storeName: Option[String] = Option.empty[String]): DataFrame = {
+ spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, checkpointDir)
+ .option(StateSourceOptions.STORE_NAME, storeName.orNull)
+ .load()
+ .selectExpr("partition_id", "key", "value")
+ }
+
+ private def getBytesReadDf(
+ checkpointDir: String,
+ storeName: Option[String] = Option.empty[String]): DataFrame = {
+ spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, checkpointDir)
+ .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES,
"true")
+ .option(StateSourceOptions.STORE_NAME, storeName.orNull)
+ .load()
+ .selectExpr("partition_key", "key_bytes", "value_bytes",
"column_family_name")
+ }
+
+ /**
+ * Validates the schema and column families of the bytes read DataFrame.
+ */
+ private def validateBytesReadDfSchema(df: DataFrame): Unit = {
+ // Verify schema
+ val schema = df.schema
+ assert(schema.fieldNames === Array(
+ "partition_key", "key_bytes", "value_bytes", "column_family_name"))
+ assert(schema("partition_key").dataType.typeName === "struct")
+ assert(schema("key_bytes").dataType.typeName === "binary")
+ assert(schema("value_bytes").dataType.typeName === "binary")
+ assert(schema("column_family_name").dataType.typeName === "string")
+ }
+
+ /**
+ * Compares normal read data with bytes read data for a specific column
family.
+ * Converts normal rows to bytes then compares with bytes read.
+ */
+ private def compareNormalAndBytesData(
+ normalDf: Array[Row],
+ bytesDf: Array[Row],
+ columnFamily: String,
+ keySchema: StructType,
+ valueSchema: StructType): Unit = {
+
+ // Filter bytes data for the specified column family and extract raw bytes
directly
+ val filteredBytesData = bytesDf.filter { row =>
+ row.getString(3) == columnFamily
+ }
+
+ // Verify same number of rows
+ assert(filteredBytesData.length == normalDf.length,
+ s"Row count mismatch for column family '$columnFamily': " +
+ s"normal read has ${normalDf.length} rows, " +
+ s"bytes read has ${filteredBytesData.length} rows")
+
+ // Create projections to convert Row to UnsafeRow bytes
+ val keyProjection = UnsafeProjection.create(keySchema)
+ val valueProjection = UnsafeProjection.create(valueSchema)
+
+ // Create converters to convert external Row types to internal Catalyst
types
+ val keyConverter =
CatalystTypeConverters.createToCatalystConverter(keySchema)
+ val valueConverter =
CatalystTypeConverters.createToCatalystConverter(valueSchema)
+
+ // Convert normal data to bytes
+ val normalAsBytes = normalDf.toSeq.map { row =>
+ val key = row.getStruct(1)
+ val value = if (row.isNullAt(2)) null else row.getStruct(2)
+
+ // Convert key to InternalRow, then to UnsafeRow, then get bytes
+ val keyInternalRow = keyConverter(key).asInstanceOf[InternalRow]
+ val keyUnsafeRow = keyProjection(keyInternalRow)
+ // IMPORTANT: Must clone the bytes array since getBytes() returns a
reference
+ // that may be overwritten by subsequent UnsafeRow operations
+ val keyBytes = keyUnsafeRow.getBytes.clone()
+
+ // Convert value to bytes
+ val valueBytes = if (value == null) {
+ Array.empty[Byte]
+ } else {
+ val valueInternalRow = valueConverter(value).asInstanceOf[InternalRow]
+ val valueUnsafeRow = valueProjection(valueInternalRow)
+ // IMPORTANT: Must clone the bytes array since getBytes() returns a
reference
+ // that may be overwritten by subsequent UnsafeRow operations
+ valueUnsafeRow.getBytes.clone()
+ }
+
+ (keyBytes, valueBytes)
+ }
+
+ // Extract raw bytes from bytes read data (no
deserialization/reserialization)
+ val bytesAsBytes = filteredBytesData.map { row =>
+ val keyBytes = row.getAs[Array[Byte]](1)
+ val valueBytes = row.getAs[Array[Byte]](2)
+ (keyBytes, valueBytes)
+ }
+
+ // Sort both for comparison (since Set equality doesn't work well with
byte arrays)
+ val normalSorted = normalAsBytes.sortBy(x => (x._1.mkString(","),
x._2.mkString(",")))
+ val bytesSorted = bytesAsBytes.sortBy(x => (x._1.mkString(","),
x._2.mkString(",")))
+
+ assert(normalSorted.length == bytesSorted.length,
+ s"Size mismatch: normal has ${normalSorted.length}, bytes has
${bytesSorted.length}")
+
+ // Compare each pair
+ normalSorted.zip(bytesSorted).zipWithIndex.foreach {
+ case (((normalKey, normalValue), (bytesKey, bytesValue)), idx) =>
+ assert(Arrays.equals(normalKey, bytesKey),
+ s"Key mismatch at index $idx:\n" +
+ s" Normal: ${normalKey.mkString("[", ",", "]")}\n" +
+ s" Bytes: ${bytesKey.mkString("[", ",", "]")}")
+ assert(Arrays.equals(normalValue, bytesValue),
+ s"Value mismatch at index $idx:\n" +
+ s" Normal: ${normalValue.mkString("[", ",", "]")}\n" +
+ s" Bytes: ${bytesValue.mkString("[", ",", "]")}")
+ }
+ }
+
+ // Run all tests with both changelog checkpointing enabled and disabled
+ Seq(true, false).foreach { changelogCheckpointingEnabled =>
+ val testSuffix = if (changelogCheckpointingEnabled) {
+ "with changelog checkpointing"
+ } else {
+ "without changelog checkpointing"
+ }
+
+ def testWithChangelogConfig(testName: String)(testFun: => Unit): Unit = {
+ test(s"$testName ($testSuffix)") {
+ withSQLConf(
+
"spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled" ->
+ changelogCheckpointingEnabled.toString) {
+ testFun
+ }
+ }
+ }
+
+ testWithChangelogConfig("SPARK-54388: simple aggregation state ver 1") {
+ withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key ->
"1") {
+ withTempDir { tempDir =>
+ runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath)
+
+ val keySchema = StructType(Array(StructField("groupKey", IntegerType,
nullable = false)))
+ // State version 1 includes key columns in the value
+ val valueSchema = StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false),
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+
+ val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect()
+ val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+
+ validateBytesReadDfSchema(bytesDf)
+ compareNormalAndBytesData(normalData, bytesDf.collect(), "default",
keySchema, valueSchema)
+ }
+ }
+ }
+
+ testWithChangelogConfig("SPARK-54388: simple aggregation state ver 2") {
+ withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key ->
"2") {
+ withTempDir { tempDir =>
+ runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath)
+
+ val keySchema = StructType(Array(StructField("groupKey", IntegerType,
nullable = false)))
+ val valueSchema = StructType(Array(
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+
+ val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect()
+ val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect()
+
+ compareNormalAndBytesData(normalData, bytesDf, "default", keySchema,
valueSchema)
Review Comment:
We are not validating the schema of `bytesDf` (i.e. calling `validateBytesReadDfSchema`)
here or in the other test cases below.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala:
##########
@@ -492,6 +503,29 @@ object StateSourceOptions extends DataSourceOptions {
val readChangeFeed =
Option(options.get(READ_CHANGE_FEED)).exists(_.toBoolean)
+ val internalOnlyReadAllColumnFamilies = try {
+
Option(options.get(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES)).exists(_.toBoolean)
+ } catch {
+ case _: IllegalArgumentException =>
+ throw
StateDataSourceErrors.invalidOptionValue(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES,
+ "Boolean value is expected")
+ }
+
+ if (internalOnlyReadAllColumnFamilies && stateVarName.isDefined) {
+ throw StateDataSourceErrors.conflictOptions(
+ Seq(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, STATE_VAR_NAME))
+ }
+
+ if (internalOnlyReadAllColumnFamilies && joinSide != JoinSideValues.none) {
Review Comment:
nit: add a comment explaining that, for this option, we use `storeName` to
specify the join store instead.
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala:
##########
@@ -0,0 +1,523 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state
+
+import java.nio.ByteOrder
+import java.util.Arrays
+
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
+import
org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider,
RocksDBStateStoreProvider, StateRepartitionUnsupportedProviderError, StateStore}
+import org.apache.spark.sql.functions.{count, sum}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.{BooleanType, IntegerType, LongType,
NullType, StructField, StructType, TimestampType}
+
+/**
+ * Note: This extends StateDataSourceTestBase to access
+ * helper methods like runDropDuplicatesQuery without inheriting all
predefined tests.
+ */
+class StatePartitionAllColumnFamiliesReaderSuite extends
StateDataSourceTestBase {
+
+ import testImplicits._
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
+ classOf[RocksDBStateStoreProvider].getName)
+ }
+
+ private def getNormalReadDf(
+ checkpointDir: String,
+ storeName: Option[String] = Option.empty[String]): DataFrame = {
+ spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, checkpointDir)
+ .option(StateSourceOptions.STORE_NAME, storeName.orNull)
+ .load()
+ .selectExpr("partition_id", "key", "value")
+ }
+
+ private def getBytesReadDf(
+ checkpointDir: String,
+ storeName: Option[String] = Option.empty[String]): DataFrame = {
+ spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, checkpointDir)
+ .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES,
"true")
+ .option(StateSourceOptions.STORE_NAME, storeName.orNull)
+ .load()
+ .selectExpr("partition_key", "key_bytes", "value_bytes",
"column_family_name")
Review Comment:
nit: the `selectExpr` is not needed here, right?
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala:
##########
@@ -201,3 +208,10 @@ class StateRepartitionUnsupportedOffsetSeqVersionError(
checkpointLocation,
subClass = "UNSUPPORTED_OFFSET_SEQ_VERSION",
messageParameters = Map("version" -> version.toString))
+
+class StateRepartitionUnsupportedProviderError(
Review Comment:
Please fix the indentation.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]