zifeif2 commented on code in PR #53104:
URL: https://github.com/apache/spark/pull/53104#discussion_r2550736299
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala:
##########
@@ -60,6 +61,17 @@ object SchemaUtil {
.add("key", keySchema)
.add("value", valueSchema)
.add("partition_id", IntegerType)
+ } else if (sourceOptions.internalOnlyReadAllColumnFamilies) {
+ new StructType()
+ // todo: change this to some more specific type after we
+ // can extract partition key from keySchema
+ .add("partition_key", keySchema)
+ .add("key_bytes", BinaryType)
+ .add("value_bytes", BinaryType)
+ .add("column_family_name", StringType)
+ // need key and value schema so that state store can encode data
+ .add("value", valueSchema)
Review Comment:
I thought RocksDBStateStoreProvider would need valueSchema and keySchema to
pass `validateStateRowFormat` after getting raw bytes from RocksDB. But it looks
like we can bypass the validation by setting `formatValidationEnabled` and
`formatValidationCheckValue` in StateStoreConf to false, so I am going to do that
in the next version.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala:
##########
@@ -237,6 +240,44 @@ class StatePartitionReader(
}
}
+/**
+ * An implementation of [[StatePartitionReaderBase]] for reading all column
families
+ * in binary format. This reader returns raw key and value bytes along with
column family names.
+ */
+class StatePartitionReaderAllColumnFamilies(
+ storeConf: StateStoreConf,
+ hadoopConf: SerializableConfiguration,
+ partition: StateStoreInputPartition,
+ schema: StructType,
+ keyStateEncoderSpec: KeyStateEncoderSpec)
+ extends StatePartitionReaderBase(storeConf, hadoopConf, partition, schema,
+ keyStateEncoderSpec, None, None, None, None) {
+
+ private lazy val store: ReadStateStore = {
+ assert(getStartStoreUniqueId == getEndStoreUniqueId,
+ "Start and end store unique IDs must be the same when reading all column
families")
+ provider.getReadStore(
+ partition.sourceOptions.batchId + 1,
+ getStartStoreUniqueId
+ )
+ }
+
+ override lazy val iter: Iterator[InternalRow] = {
+ // Single store with column families (join v3, transformWithState, or
simple operators)
Review Comment:
It's not supposed to be there. I'll remove it.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala:
##########
@@ -65,6 +65,13 @@ class StateDataSource extends TableProvider with
DataSourceRegister with Logging
val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf,
StateSourceOptions.apply(session, hadoopConf, properties))
val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation,
sourceOptions.batchId)
+ if (sourceOptions.internalOnlyReadAllColumnFamilies
+ && !stateConf.providerClass.contains("RocksDB")) {
+ throw StateDataSourceErrors.invalidOptionValue(
+ StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES,
Review Comment:
True! I added a new error class in OfflineStateRepartitionErrors in the next
version.
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala:
##########
@@ -0,0 +1,244 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state
+
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.catalyst.expressions.{GenericRowWithSchema,
UnsafeRow}
+import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
+import
org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider,
RocksDBStateStoreProvider}
+import org.apache.spark.sql.functions.{count, sum}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.{IntegerType, LongType, StructField,
StructType}
+import org.apache.spark.tags.SlowSQLTest
+import org.apache.spark.unsafe.Platform
+
+/**
+ * Test suite to verify StatePartitionReaderAllColumnFamilies functionality.
+ */
+@SlowSQLTest
+class StatePartitionReaderAllColumnFamiliesSuite extends
StateDataSourceTestBase {
+
+ import testImplicits._
+
+ /**
+ * Returns a set of (partitionId, key, value) tuples from a normal state
read.
+ */
+ private def getNormalReadData(checkpointDir: String): DataFrame = {
+ spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, checkpointDir)
+ .load()
+ .selectExpr("partition_id", "key", "value")
+ }
+
+ /**
+ * Returns a DataFrame with raw bytes mode
(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES = true).
+ */
+ private def getBytesReadDf(checkpointDir: String): DataFrame = {
+ spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, checkpointDir)
+ .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES,
"true")
+ .load()
+ }
+
+ /**
+ * Validates the schema and column families of the bytes read DataFrame.
+ */
+ private def validateBytesReadSchema(df: DataFrame): Unit = {
+ // Verify schema
+ val schema = df.schema
+ assert(schema.fieldNames === Array(
+ "partition_key", "key_bytes", "value_bytes", "column_family_name",
"value", "key"))
+ assert(schema("partition_key").dataType.typeName === "struct")
+ assert(schema("key_bytes").dataType.typeName === "binary")
+ assert(schema("value_bytes").dataType.typeName === "binary")
+ assert(schema("column_family_name").dataType.typeName === "string")
+ }
+
+ private def parseBytesReadData(df: Array[Row], keyLength: Int, valueLength:
Int)
+ : Set[(GenericRowWithSchema, UnsafeRow, UnsafeRow, String)] = {
+ df.map { row =>
+ val partitionKey = row.getAs[GenericRowWithSchema](0)
+ val keyBytes = row.getAs[Array[Byte]](1)
+ val valueBytes = row.getAs[Array[Byte]](2)
+ val columnFamily = row.getString(3)
+
+ // Deserialize key bytes to UnsafeRow
+ val keyRow = new UnsafeRow(keyLength)
+ keyRow.pointTo(
+ keyBytes,
+ Platform.BYTE_ARRAY_OFFSET,
+ keyBytes.length)
+
+ // Deserialize value bytes to UnsafeRow
+ val valueRow = new UnsafeRow(valueLength)
+ valueRow.pointTo(
+ valueBytes,
+ Platform.BYTE_ARRAY_OFFSET,
+ valueBytes.length)
+ (partitionKey, keyRow.copy(), valueRow.copy(), columnFamily)
+ }
+ .toSet
+ }
+
+ /**
+ * Compares normal read data with bytes read data for a specific column
family.
+ */
+ private def compareNormalAndBytesData(
+ normalReadDf: DataFrame,
+ bytesReadDf: DataFrame,
+ columnFamily: String,
+ keySchema: StructType,
+ valueSchema: StructType): Unit = {
+ // Verify data
+ val bytesDf = bytesReadDf
+ .selectExpr("partition_key", "key_bytes", "value_bytes",
"column_family_name")
+ .collect()
+ assert(bytesDf.length == 10,
+ s"Expected 10 rows but got: ${bytesDf.length}")
+
+ // Filter bytes data for the specified column family
+ val bytesData = parseBytesReadData(bytesDf, keySchema.length,
valueSchema.length)
+ val filteredBytesData = bytesData.filter(_._4 == columnFamily)
+
+ // Apply the projection
+ // Convert to comparable format (extract field values)
+ val normalSet = normalReadDf.collect().map { row =>
+ val key = row.getStruct(1)
+ val value = row.getStruct(2)
+ val keyFields = (0 until key.length).map(i => key.get(i))
+ val valueFields = (0 until value.length).map(i => value.get(i))
+ (keyFields, valueFields)
+ }.toSet
+
+ val bytesSet = filteredBytesData.map { case (_, keyRow, valueRow, _) =>
+ val keyFields = (0 until keySchema.length).map(i =>
+ keyRow.get(i, keySchema(i).dataType))
+ val valueFields = (0 until valueSchema.length).map(i =>
+ valueRow.get(i, valueSchema(i).dataType))
+ (keyFields, valueFields)
+ }
+ // Verify same number of rows
+ assert(filteredBytesData.size == normalSet.size,
Review Comment:
Refactored in the next version so that it's more obvious that we compare the
"number of rows from rawBytesData that belong to a column family" against the
"normal DataFrame".
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala:
##########
@@ -76,6 +88,22 @@ object SchemaUtil {
row
}
+ /**
+ * Creates a unified row from raw key and value bytes.
+ * This is an alias for unifyStateRowPairAsBytes that takes individual byte
arrays
+ * instead of a tuple for better readability.
+ */
+ def unifyStateRowPairAsRawBytes(
+ pair: (UnsafeRow, UnsafeRow),
+ colFamilyName: String): InternalRow = {
+ val row = new GenericInternalRow(6)
Review Comment:
Because the schema previously had 6 fields, but we only need 4 (we don't
need key and value for repartitioning). The next version removes the key and value
columns from the schema, so I'll change this to 4 instead.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]