micheal-o commented on code in PR #53104:
URL: https://github.com/apache/spark/pull/53104#discussion_r2548006075


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala:
##########
@@ -65,6 +65,13 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
     val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf,
       StateSourceOptions.apply(session, hadoopConf, properties))
     val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, sourceOptions.batchId)
+    if (sourceOptions.internalOnlyReadAllColumnFamilies
+      && !stateConf.providerClass.contains("RocksDB")) {

Review Comment:
   nit: Just compare the class. Look at how we do this in other places
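   For illustration, a minimal sketch of the suggested comparison, assuming `stateConf.providerClass` holds the configured class name and `RocksDBStateStoreProvider` can be imported here:

   ```scala
   // Hedged sketch: compare against the provider class name directly instead of a
   // "RocksDB" substring match. Assumes RocksDBStateStoreProvider is in scope.
   if (sourceOptions.internalOnlyReadAllColumnFamilies &&
       stateConf.providerClass != classOf[RocksDBStateStoreProvider].getName) {
     // raise the same error as in the current change
   }
   ```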



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala:
##########
@@ -237,6 +240,44 @@ class StatePartitionReader(
   }
 }
 
+/**
+ * An implementation of [[StatePartitionReaderBase]] for reading all column families
+ * in binary format. This reader returns raw key and value bytes along with column family names.

Review Comment:
   nit: Add that "We are returning key/value bytes because each column family can have different schema"
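   A possible wording for the updated scaladoc (the phrasing is an editorial sketch, not the author's):

   ```scala
   /**
    * An implementation of [[StatePartitionReaderBase]] that reads all column families
    * in binary format, returning raw key and value bytes along with the column family
    * name. We return key/value bytes rather than typed rows because each column family
    * can have a different schema.
    */
   ```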



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala:
##########
@@ -76,6 +88,22 @@ object SchemaUtil {
     row
   }
 
+  /**
+   * Creates a unified row from raw key and value bytes.
+   * This is an alias for unifyStateRowPairAsBytes that takes individual byte arrays
+   * instead of a tuple for better readability.
+   */
+  def unifyStateRowPairAsRawBytes(
+     pair: (UnsafeRow, UnsafeRow),
+     colFamilyName: String): InternalRow = {
+    val row = new GenericInternalRow(6)

Review Comment:
   This is 6, but you're only updating 4?
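   To make the mismatch concrete, here is a rough sketch of populating all six fields of the unified schema (partition_key, key_bytes, value_bytes, column_family_name, value, key); the field order and the partition-key placeholder are assumptions:

   ```scala
   // Sketch only; positions follow the schema added in SchemaUtil, and the
   // partition key is still a placeholder (see the todo discussed further down).
   val row = new GenericInternalRow(6)
   row.update(0, pair._1)                               // TODO: actual partition key
   row.update(1, pair._1.getBytes)                      // raw key bytes
   row.update(2, pair._2.getBytes)                      // raw value bytes
   row.update(3, UTF8String.fromString(colFamilyName))  // column family name
   row.update(4, pair._2)                               // value row
   row.update(5, pair._1)                               // key row
   row
   ```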



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala:
##########
@@ -237,6 +240,44 @@ class StatePartitionReader(
   }
 }
 
+/**
+ * An implementation of [[StatePartitionReaderBase]] for reading all column families
+ * in binary format. This reader returns raw key and value bytes along with column family names.
+ */
+class StatePartitionReaderAllColumnFamilies(

Review Comment:
   nit: `StatePartitionAllColumnFamiliesReader`



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala:
##########
@@ -0,0 +1,244 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state
+
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.catalyst.expressions.{GenericRowWithSchema, 
UnsafeRow}
+import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
+import 
org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, 
RocksDBStateStoreProvider}
+import org.apache.spark.sql.functions.{count, sum}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.{IntegerType, LongType, StructField, 
StructType}
+import org.apache.spark.tags.SlowSQLTest
+import org.apache.spark.unsafe.Platform
+
+/**
+ * Test suite to verify StatePartitionReaderAllColumnFamilies functionality.
+ */
+@SlowSQLTest
+class StatePartitionReaderAllColumnFamiliesSuite extends 
StateDataSourceTestBase {
+
+  import testImplicits._
+
+  /**
+   * Returns a set of (partitionId, key, value) tuples from a normal state 
read.
+   */
+  private def getNormalReadData(checkpointDir: String): DataFrame = {

Review Comment:
   nit: `getNormalReadDf`



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala:
##########
@@ -0,0 +1,244 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state
+
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.catalyst.expressions.{GenericRowWithSchema, 
UnsafeRow}
+import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
+import 
org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, 
RocksDBStateStoreProvider}
+import org.apache.spark.sql.functions.{count, sum}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.{IntegerType, LongType, StructField, 
StructType}
+import org.apache.spark.tags.SlowSQLTest
+import org.apache.spark.unsafe.Platform
+
+/**
+ * Test suite to verify StatePartitionReaderAllColumnFamilies functionality.
+ */
+@SlowSQLTest
+class StatePartitionReaderAllColumnFamiliesSuite extends 
StateDataSourceTestBase {
+
+  import testImplicits._
+
+  /**
+   * Returns a set of (partitionId, key, value) tuples from a normal state 
read.
+   */
+  private def getNormalReadData(checkpointDir: String): DataFrame = {
+    spark.read
+      .format("statestore")
+      .option(StateSourceOptions.PATH, checkpointDir)
+      .load()
+      .selectExpr("partition_id", "key", "value")
+  }
+
+  /**
+   * Returns a DataFrame with raw bytes mode 
(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES = true).
+   */
+  private def getBytesReadDf(checkpointDir: String): DataFrame = {
+    spark.read
+      .format("statestore")
+      .option(StateSourceOptions.PATH, checkpointDir)
+      .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, 
"true")
+      .load()
+  }
+
+  /**
+   * Validates the schema and column families of the bytes read DataFrame.
+   */
+  private def validateBytesReadSchema(df: DataFrame): Unit = {
+    // Verify schema
+    val schema = df.schema
+    assert(schema.fieldNames === Array(
+      "partition_key", "key_bytes", "value_bytes", "column_family_name", 
"value", "key"))
+    assert(schema("partition_key").dataType.typeName === "struct")
+    assert(schema("key_bytes").dataType.typeName === "binary")
+    assert(schema("value_bytes").dataType.typeName === "binary")
+    assert(schema("column_family_name").dataType.typeName === "string")
+  }
+
+  private def parseBytesReadData(df: Array[Row], keyLength: Int, valueLength: 
Int)
+    : Set[(GenericRowWithSchema, UnsafeRow, UnsafeRow, String)] = {
+    df.map { row =>
+        val partitionKey = row.getAs[GenericRowWithSchema](0)
+        val keyBytes = row.getAs[Array[Byte]](1)
+        val valueBytes = row.getAs[Array[Byte]](2)
+        val columnFamily = row.getString(3)
+
+        // Deserialize key bytes to UnsafeRow
+        val keyRow = new UnsafeRow(keyLength)
+        keyRow.pointTo(
+          keyBytes,
+          Platform.BYTE_ARRAY_OFFSET,
+          keyBytes.length)
+
+        // Deserialize value bytes to UnsafeRow
+        val valueRow = new UnsafeRow(valueLength)
+        valueRow.pointTo(
+          valueBytes,
+          Platform.BYTE_ARRAY_OFFSET,
+          valueBytes.length)
+        (partitionKey, keyRow.copy(), valueRow.copy(), columnFamily)
+      }
+      .toSet
+  }
+
+  /**
+   * Compares normal read data with bytes read data for a specific column 
family.
+   */
+  private def compareNormalAndBytesData(
+      normalReadDf: DataFrame,
+      bytesReadDf: DataFrame,
+      columnFamily: String,
+      keySchema: StructType,
+      valueSchema: StructType): Unit = {
+    // Verify data
+    val bytesDf = bytesReadDf
+      .selectExpr("partition_key", "key_bytes", "value_bytes", 
"column_family_name")
+      .collect()
+    assert(bytesDf.length == 10,
+      s"Expected 10 rows but got: ${bytesDf.length}")
+
+    // Filter bytes data for the specified column family
+    val bytesData = parseBytesReadData(bytesDf, keySchema.length, 
valueSchema.length)
+    val filteredBytesData = bytesData.filter(_._4 == columnFamily)
+
+    // Apply the projection
+    // Convert to comparable format (extract field values)
+    val normalSet = normalReadDf.collect().map { row =>
+      val key = row.getStruct(1)
+      val value = row.getStruct(2)
+      val keyFields = (0 until key.length).map(i => key.get(i))
+      val valueFields = (0 until value.length).map(i => value.get(i))
+      (keyFields, valueFields)
+    }.toSet
+
+    val bytesSet = filteredBytesData.map { case (_, keyRow, valueRow, _) =>
+      val keyFields = (0 until keySchema.length).map(i =>
+        keyRow.get(i, keySchema(i).dataType))
+      val valueFields = (0 until valueSchema.length).map(i =>
+        valueRow.get(i, valueSchema(i).dataType))
+      (keyFields, valueFields)
+    }
+    // Verify same number of rows
+    assert(filteredBytesData.size == normalSet.size,
+      s"Row count mismatch for column family '$columnFamily': " +
+        s"normal read has ${filteredBytesData.size} rows, bytes read has 
${normalSet.size} rows")

Review Comment:
   Wrong, switch the vals? The message uses `filteredBytesData.size` for the normal read and `normalSet.size` for the bytes read, which is backwards.
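   A sketch of the corrected message, with `normalSet` coming from the normal read and `filteredBytesData` from the bytes read:

   ```scala
   assert(filteredBytesData.size == normalSet.size,
     s"Row count mismatch for column family '$columnFamily': " +
       s"normal read has ${normalSet.size} rows, bytes read has ${filteredBytesData.size} rows")
   ```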



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala:
##########
@@ -0,0 +1,244 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state
+
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.catalyst.expressions.{GenericRowWithSchema, 
UnsafeRow}
+import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
+import 
org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, 
RocksDBStateStoreProvider}
+import org.apache.spark.sql.functions.{count, sum}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.{IntegerType, LongType, StructField, 
StructType}
+import org.apache.spark.tags.SlowSQLTest
+import org.apache.spark.unsafe.Platform
+
+/**
+ * Test suite to verify StatePartitionReaderAllColumnFamilies functionality.
+ */
+@SlowSQLTest
+class StatePartitionReaderAllColumnFamiliesSuite extends 
StateDataSourceTestBase {
+
+  import testImplicits._
+
+  /**
+   * Returns a set of (partitionId, key, value) tuples from a normal state 
read.
+   */
+  private def getNormalReadData(checkpointDir: String): DataFrame = {
+    spark.read
+      .format("statestore")
+      .option(StateSourceOptions.PATH, checkpointDir)
+      .load()
+      .selectExpr("partition_id", "key", "value")
+  }
+
+  /**
+   * Returns a DataFrame with raw bytes mode 
(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES = true).
+   */
+  private def getBytesReadDf(checkpointDir: String): DataFrame = {
+    spark.read
+      .format("statestore")
+      .option(StateSourceOptions.PATH, checkpointDir)
+      .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, 
"true")
+      .load()
+  }
+
+  /**
+   * Validates the schema and column families of the bytes read DataFrame.
+   */
+  private def validateBytesReadSchema(df: DataFrame): Unit = {
+    // Verify schema
+    val schema = df.schema
+    assert(schema.fieldNames === Array(
+      "partition_key", "key_bytes", "value_bytes", "column_family_name", 
"value", "key"))
+    assert(schema("partition_key").dataType.typeName === "struct")
+    assert(schema("key_bytes").dataType.typeName === "binary")
+    assert(schema("value_bytes").dataType.typeName === "binary")
+    assert(schema("column_family_name").dataType.typeName === "string")
+  }
+
+  private def parseBytesReadData(df: Array[Row], keyLength: Int, valueLength: 
Int)
+    : Set[(GenericRowWithSchema, UnsafeRow, UnsafeRow, String)] = {
+    df.map { row =>
+        val partitionKey = row.getAs[GenericRowWithSchema](0)
+        val keyBytes = row.getAs[Array[Byte]](1)
+        val valueBytes = row.getAs[Array[Byte]](2)
+        val columnFamily = row.getString(3)
+
+        // Deserialize key bytes to UnsafeRow
+        val keyRow = new UnsafeRow(keyLength)
+        keyRow.pointTo(
+          keyBytes,
+          Platform.BYTE_ARRAY_OFFSET,
+          keyBytes.length)
+
+        // Deserialize value bytes to UnsafeRow
+        val valueRow = new UnsafeRow(valueLength)
+        valueRow.pointTo(
+          valueBytes,
+          Platform.BYTE_ARRAY_OFFSET,
+          valueBytes.length)
+        (partitionKey, keyRow.copy(), valueRow.copy(), columnFamily)
+      }
+      .toSet
+  }
+
+  /**
+   * Compares normal read data with bytes read data for a specific column 
family.
+   */
+  private def compareNormalAndBytesData(
+      normalReadDf: DataFrame,
+      bytesReadDf: DataFrame,
+      columnFamily: String,
+      keySchema: StructType,
+      valueSchema: StructType): Unit = {
+    // Verify data
+    val bytesDf = bytesReadDf
+      .selectExpr("partition_key", "key_bytes", "value_bytes", 
"column_family_name")
+      .collect()
+    assert(bytesDf.length == 10,
+      s"Expected 10 rows but got: ${bytesDf.length}")
+
+    // Filter bytes data for the specified column family
+    val bytesData = parseBytesReadData(bytesDf, keySchema.length, 
valueSchema.length)
+    val filteredBytesData = bytesData.filter(_._4 == columnFamily)
+
+    // Apply the projection
+    // Convert to comparable format (extract field values)
+    val normalSet = normalReadDf.collect().map { row =>
+      val key = row.getStruct(1)
+      val value = row.getStruct(2)
+      val keyFields = (0 until key.length).map(i => key.get(i))
+      val valueFields = (0 until value.length).map(i => value.get(i))
+      (keyFields, valueFields)
+    }.toSet
+
+    val bytesSet = filteredBytesData.map { case (_, keyRow, valueRow, _) =>
+      val keyFields = (0 until keySchema.length).map(i =>
+        keyRow.get(i, keySchema(i).dataType))
+      val valueFields = (0 until valueSchema.length).map(i =>
+        valueRow.get(i, valueSchema(i).dataType))
+      (keyFields, valueFields)
+    }
+    // Verify same number of rows
+    assert(filteredBytesData.size == normalSet.size,
+      s"Row count mismatch for column family '$columnFamily': " +
+        s"normal read has ${filteredBytesData.size} rows, bytes read has 
${normalSet.size} rows")
+
+    assert(normalSet == bytesSet)
+  }
+
+    test(s"read all column families with simple operator") {

Review Comment:
   You need some more test cases to cover all operators that use a single CF, e.g. agg with composite key, DropDuplicate, DropDuplicateWithinWatermark, SessionWindow, FMGWS, etc. See `StateDataSourceReadSuite`; it already has a bunch of query functions you can reuse and example test cases.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala:
##########
@@ -371,6 +378,7 @@ case class StateSourceOptions(
     stateVarName: Option[String],
     readRegisteredTimers: Boolean,
     flattenCollectionTypes: Boolean,
+    internalOnlyReadAllColumnFamilies: Boolean,

Review Comment:
   nit: default to false? 
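   A one-line sketch of the suggested default (whether any call sites rely on the current positional arguments is not checked here):

   ```scala
   internalOnlyReadAllColumnFamilies: Boolean = false,
   ```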



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala:
##########
@@ -65,6 +65,13 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
     val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf,
       StateSourceOptions.apply(session, hadoopConf, properties))
     val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, sourceOptions.batchId)
+    if (sourceOptions.internalOnlyReadAllColumnFamilies
+      && !stateConf.providerClass.contains("RocksDB")) {
+      throw StateDataSourceErrors.invalidOptionValue(
+        StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES,

Review Comment:
   This is an internal-only option, so we shouldn't expose it in the user error message, right?
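   One hedged alternative, assuming an internal error is acceptable here; `SparkException.internalError` is used elsewhere for conditions users shouldn't hit:

   ```scala
   // Sketch: fail with an internal error instead of a user-facing option error,
   // since the option is internal-only. The wording is illustrative.
   throw SparkException.internalError(
     "Reading all column families requires the RocksDB state store provider, " +
       s"but found ${stateConf.providerClass}")
   ```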



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala:
##########
@@ -60,6 +61,17 @@ object SchemaUtil {
         .add("key", keySchema)
         .add("value", valueSchema)
         .add("partition_id", IntegerType)
+    } else if (sourceOptions.internalOnlyReadAllColumnFamilies) {
+      new StructType()
+        // todo: change this to some more specific type after we
+        //  can extract partition key from keySchema
+        .add("partition_key", keySchema)
+        .add("key_bytes", BinaryType)
+        .add("value_bytes", BinaryType)
+        .add("column_family_name", StringType)
+        // need key and value schema so that state store can encode data
+        .add("value", valueSchema)

Review Comment:
   why do we need the `key` and `value` row here?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala:
##########
@@ -76,6 +88,22 @@ object SchemaUtil {
     row
   }
 
+  /**
+   * Creates a unified row from raw key and value bytes.
+   * This is an alias for unifyStateRowPairAsBytes that takes individual byte arrays

Review Comment:
   not sure what this line means?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala:
##########
@@ -237,6 +240,44 @@ class StatePartitionReader(
   }
 }
 
+/**
+ * An implementation of [[StatePartitionReaderBase]] for reading all column families
+ * in binary format. This reader returns raw key and value bytes along with column family names.
+ */
+class StatePartitionReaderAllColumnFamilies(
+    storeConf: StateStoreConf,
+    hadoopConf: SerializableConfiguration,
+    partition: StateStoreInputPartition,
+    schema: StructType,
+    keyStateEncoderSpec: KeyStateEncoderSpec)
+  extends StatePartitionReaderBase(storeConf, hadoopConf, partition, schema,
+    keyStateEncoderSpec, None, None, None, None) {
+
+  private lazy val store: ReadStateStore = {
+    assert(getStartStoreUniqueId == getEndStoreUniqueId,
+      "Start and end store unique IDs must be the same when reading all column 
families")
+    provider.getReadStore(
+      partition.sourceOptions.batchId + 1,
+      getStartStoreUniqueId
+    )
+  }
+
+  override lazy val iter: Iterator[InternalRow] = {
+    // Single store with column families (join v3, transformWithState, or simple operators)

Review Comment:
   what does this comment mean?



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala:
##########
@@ -0,0 +1,244 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state
+
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.catalyst.expressions.{GenericRowWithSchema, 
UnsafeRow}
+import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
+import 
org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, 
RocksDBStateStoreProvider}
+import org.apache.spark.sql.functions.{count, sum}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.{IntegerType, LongType, StructField, 
StructType}
+import org.apache.spark.tags.SlowSQLTest
+import org.apache.spark.unsafe.Platform
+
+/**
+ * Test suite to verify StatePartitionReaderAllColumnFamilies functionality.

Review Comment:
   ditto



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala:
##########
@@ -231,7 +259,11 @@ object SchemaUtil {
       "user_map_key" -> classOf[StructType],
       "user_map_value" -> classOf[StructType],
       "expiration_timestamp_ms" -> classOf[LongType],
-      "partition_id" -> classOf[IntegerType])
+      "partition_id" -> classOf[IntegerType],
+      "partition_key" -> classOf[StructType],
+      "key_bytes"->classOf[BinaryType],
+      "value_bytes"->classOf[BinaryType],
+      "column_family_name"->classOf[StringType])

Review Comment:
   nit: fix the spacing around `->` for `key_bytes`, `value_bytes`, and `column_family_name`
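   i.e. the reformatted entries would read:

   ```scala
   "key_bytes" -> classOf[BinaryType],
   "value_bytes" -> classOf[BinaryType],
   "column_family_name" -> classOf[StringType])
   ```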



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala:
##########
@@ -76,6 +88,22 @@ object SchemaUtil {
     row
   }
 
+  /**
+   * Creates a unified row from raw key and value bytes.
+   * This is an alias for unifyStateRowPairAsBytes that takes individual byte arrays
+   * instead of a tuple for better readability.
+   */
+  def unifyStateRowPairAsRawBytes(
+     pair: (UnsafeRow, UnsafeRow),
+     colFamilyName: String): InternalRow = {
+    val row = new GenericInternalRow(6)
+    row.update(0, pair._1)

Review Comment:
   Add todo for setting the actual partition key



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala:
##########
@@ -76,6 +88,22 @@ object SchemaUtil {
     row
   }
 
+  /**
+   * Creates a unified row from raw key and value bytes.
+   * This is an alias for unifyStateRowPairAsBytes that takes individual byte arrays
+   * instead of a tuple for better readability.
+   */
+  def unifyStateRowPairAsRawBytes(
+     pair: (UnsafeRow, UnsafeRow),

Review Comment:
   nit: fix indentation



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala:
##########
@@ -0,0 +1,244 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state
+
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.catalyst.expressions.{GenericRowWithSchema, 
UnsafeRow}
+import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
+import 
org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, 
RocksDBStateStoreProvider}
+import org.apache.spark.sql.functions.{count, sum}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.{IntegerType, LongType, StructField, 
StructType}
+import org.apache.spark.tags.SlowSQLTest
+import org.apache.spark.unsafe.Platform
+
+/**
+ * Test suite to verify StatePartitionReaderAllColumnFamilies functionality.
+ */
+@SlowSQLTest
+class StatePartitionReaderAllColumnFamiliesSuite extends 
StateDataSourceTestBase {

Review Comment:
   ditto



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala:
##########
@@ -0,0 +1,244 @@
+/*

Review Comment:
   `StatePartitionAllColumnFamiliesReaderSuite.scala`



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala:
##########
@@ -0,0 +1,244 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state
+
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.catalyst.expressions.{GenericRowWithSchema, 
UnsafeRow}
+import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
+import 
org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, 
RocksDBStateStoreProvider}
+import org.apache.spark.sql.functions.{count, sum}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.{IntegerType, LongType, StructField, 
StructType}
+import org.apache.spark.tags.SlowSQLTest
+import org.apache.spark.unsafe.Platform
+
+/**
+ * Test suite to verify StatePartitionReaderAllColumnFamilies functionality.
+ */
+@SlowSQLTest

Review Comment:
   why is this marked as slow?



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala:
##########
@@ -0,0 +1,244 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state
+
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.catalyst.expressions.{GenericRowWithSchema, 
UnsafeRow}
+import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
+import 
org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, 
RocksDBStateStoreProvider}
+import org.apache.spark.sql.functions.{count, sum}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.{IntegerType, LongType, StructField, 
StructType}
+import org.apache.spark.tags.SlowSQLTest
+import org.apache.spark.unsafe.Platform
+
+/**
+ * Test suite to verify StatePartitionReaderAllColumnFamilies functionality.
+ */
+@SlowSQLTest
+class StatePartitionReaderAllColumnFamiliesSuite extends 
StateDataSourceTestBase {
+
+  import testImplicits._
+
+  /**
+   * Returns a set of (partitionId, key, value) tuples from a normal state 
read.
+   */
+  private def getNormalReadData(checkpointDir: String): DataFrame = {
+    spark.read
+      .format("statestore")
+      .option(StateSourceOptions.PATH, checkpointDir)
+      .load()
+      .selectExpr("partition_id", "key", "value")
+  }
+
+  /**
+   * Returns a DataFrame with raw bytes mode 
(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES = true).
+   */
+  private def getBytesReadDf(checkpointDir: String): DataFrame = {
+    spark.read
+      .format("statestore")
+      .option(StateSourceOptions.PATH, checkpointDir)
+      .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, 
"true")
+      .load()
+  }
+
+  /**
+   * Validates the schema and column families of the bytes read DataFrame.
+   */
+  private def validateBytesReadSchema(df: DataFrame): Unit = {
+    // Verify schema
+    val schema = df.schema
+    assert(schema.fieldNames === Array(
+      "partition_key", "key_bytes", "value_bytes", "column_family_name", 
"value", "key"))
+    assert(schema("partition_key").dataType.typeName === "struct")
+    assert(schema("key_bytes").dataType.typeName === "binary")
+    assert(schema("value_bytes").dataType.typeName === "binary")
+    assert(schema("column_family_name").dataType.typeName === "string")
+  }
+
+  private def parseBytesReadData(df: Array[Row], keyLength: Int, valueLength: 
Int)

Review Comment:
   nit: `convertBytesToRow`?



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala:
##########
@@ -0,0 +1,244 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state
+
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.catalyst.expressions.{GenericRowWithSchema, 
UnsafeRow}
+import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
+import 
org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, 
RocksDBStateStoreProvider}
+import org.apache.spark.sql.functions.{count, sum}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.{IntegerType, LongType, StructField, 
StructType}
+import org.apache.spark.tags.SlowSQLTest
+import org.apache.spark.unsafe.Platform
+
+/**
+ * Test suite to verify StatePartitionReaderAllColumnFamilies functionality.
+ */
+@SlowSQLTest
+class StatePartitionReaderAllColumnFamiliesSuite extends 
StateDataSourceTestBase {

Review Comment:
   Let's make the test cases in this suite run with both checkpoint v1 and v2.
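   A rough sketch of parameterizing the suite over both formats; the conf constant name (`STATE_STORE_CHECKPOINT_FORMAT_VERSION`) is an assumption here:

   ```scala
   // Sketch: run each test under checkpoint format v1 and v2. The exact conf
   // constant used to pick the format is an assumption.
   Seq(1, 2).foreach { ckptVersion =>
     test(s"read all column families with simple operator (checkpoint v$ckptVersion)") {
       withSQLConf(
         SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> ckptVersion.toString) {
         // ... existing test body ...
       }
     }
   }
   ```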



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala:
##########
@@ -0,0 +1,244 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state
+
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.catalyst.expressions.{GenericRowWithSchema, 
UnsafeRow}
+import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
+import 
org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, 
RocksDBStateStoreProvider}
+import org.apache.spark.sql.functions.{count, sum}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.{IntegerType, LongType, StructField, 
StructType}
+import org.apache.spark.tags.SlowSQLTest
+import org.apache.spark.unsafe.Platform
+
+/**
+ * Test suite to verify StatePartitionReaderAllColumnFamilies functionality.
+ */
+@SlowSQLTest
+class StatePartitionReaderAllColumnFamiliesSuite extends 
StateDataSourceTestBase {
+
+  import testImplicits._
+
+  /**
+   * Returns a set of (partitionId, key, value) tuples from a normal state 
read.
+   */
+  private def getNormalReadData(checkpointDir: String): DataFrame = {
+    spark.read
+      .format("statestore")
+      .option(StateSourceOptions.PATH, checkpointDir)
+      .load()
+      .selectExpr("partition_id", "key", "value")
+  }
+
+  /**
+   * Returns a DataFrame with raw bytes mode 
(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES = true).
+   */
+  private def getBytesReadDf(checkpointDir: String): DataFrame = {
+    spark.read
+      .format("statestore")
+      .option(StateSourceOptions.PATH, checkpointDir)
+      .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, 
"true")
+      .load()
+  }
+
+  /**
+   * Validates the schema and column families of the bytes read DataFrame.
+   */
+  private def validateBytesReadSchema(df: DataFrame): Unit = {
+    // Verify schema
+    val schema = df.schema
+    assert(schema.fieldNames === Array(
+      "partition_key", "key_bytes", "value_bytes", "column_family_name", 
"value", "key"))
+    assert(schema("partition_key").dataType.typeName === "struct")
+    assert(schema("key_bytes").dataType.typeName === "binary")
+    assert(schema("value_bytes").dataType.typeName === "binary")
+    assert(schema("column_family_name").dataType.typeName === "string")
+  }
+
+  private def parseBytesReadData(df: Array[Row], keyLength: Int, valueLength: 
Int)
+    : Set[(GenericRowWithSchema, UnsafeRow, UnsafeRow, String)] = {
+    df.map { row =>
+        val partitionKey = row.getAs[GenericRowWithSchema](0)
+        val keyBytes = row.getAs[Array[Byte]](1)
+        val valueBytes = row.getAs[Array[Byte]](2)
+        val columnFamily = row.getString(3)
+
+        // Deserialize key bytes to UnsafeRow
+        val keyRow = new UnsafeRow(keyLength)
+        keyRow.pointTo(
+          keyBytes,
+          Platform.BYTE_ARRAY_OFFSET,
+          keyBytes.length)
+
+        // Deserialize value bytes to UnsafeRow
+        val valueRow = new UnsafeRow(valueLength)
+        valueRow.pointTo(
+          valueBytes,
+          Platform.BYTE_ARRAY_OFFSET,
+          valueBytes.length)
+        (partitionKey, keyRow.copy(), valueRow.copy(), columnFamily)
+      }
+      .toSet
+  }
+
+  /**
+   * Compares normal read data with bytes read data for a specific column 
family.
+   */
+  private def compareNormalAndBytesData(
+      normalReadDf: DataFrame,
+      bytesReadDf: DataFrame,
+      columnFamily: String,
+      keySchema: StructType,
+      valueSchema: StructType): Unit = {
+    // Verify data
+    val bytesDf = bytesReadDf
+      .selectExpr("partition_key", "key_bytes", "value_bytes", 
"column_family_name")
+      .collect()
+    assert(bytesDf.length == 10,
+      s"Expected 10 rows but got: ${bytesDf.length}")
+
+    // Filter bytes data for the specified column family
+    val bytesData = parseBytesReadData(bytesDf, keySchema.length, 
valueSchema.length)
+    val filteredBytesData = bytesData.filter(_._4 == columnFamily)
+
+    // Apply the projection
+    // Convert to comparable format (extract field values)
+    val normalSet = normalReadDf.collect().map { row =>
+      val key = row.getStruct(1)
+      val value = row.getStruct(2)
+      val keyFields = (0 until key.length).map(i => key.get(i))
+      val valueFields = (0 until value.length).map(i => value.get(i))
+      (keyFields, valueFields)
+    }.toSet
+
+    val bytesSet = filteredBytesData.map { case (_, keyRow, valueRow, _) =>
+      val keyFields = (0 until keySchema.length).map(i =>
+        keyRow.get(i, keySchema(i).dataType))
+      val valueFields = (0 until valueSchema.length).map(i =>
+        valueRow.get(i, valueSchema(i).dataType))
+      (keyFields, valueFields)
+    }
+    // Verify same number of rows
+    assert(filteredBytesData.size == normalSet.size,
+      s"Row count mismatch for column family '$columnFamily': " +
+        s"normal read has ${filteredBytesData.size} rows, bytes read has 
${normalSet.size} rows")
+
+    assert(normalSet == bytesSet)
+  }
+
+    test(s"read all column families with simple operator") {
+      withTempDir { tempDir =>
+        withSQLConf(
+          SQLConf.STATE_STORE_PROVIDER_CLASS.key -> 
classOf[RocksDBStateStoreProvider].getName,

Review Comment:
   Should we just set this at the class level, so that all test cases use RocksDB except when a test case overrides the conf?
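   For example, a sketch of a class-level override, assuming the suite inherits `sparkConf` from the shared test base (import of `SparkConf` implied):

   ```scala
   // Sketch: make RocksDB the provider for the whole suite; individual tests can
   // still override it with withSQLConf.
   override protected def sparkConf: SparkConf =
     super.sparkConf.set(
       SQLConf.STATE_STORE_PROVIDER_CLASS.key,
       classOf[RocksDBStateStoreProvider].getName)
   ```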



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala:
##########
@@ -0,0 +1,244 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state
+
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.catalyst.expressions.{GenericRowWithSchema, 
UnsafeRow}
+import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
+import 
org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, 
RocksDBStateStoreProvider}
+import org.apache.spark.sql.functions.{count, sum}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.{IntegerType, LongType, StructField, 
StructType}
+import org.apache.spark.tags.SlowSQLTest
+import org.apache.spark.unsafe.Platform
+
+/**
+ * Test suite to verify StatePartitionReaderAllColumnFamilies functionality.
+ */
+@SlowSQLTest
+class StatePartitionReaderAllColumnFamiliesSuite extends 
StateDataSourceTestBase {
+
+  import testImplicits._
+
+  /**
+   * Returns a set of (partitionId, key, value) tuples from a normal state 
read.
+   */
+  private def getNormalReadData(checkpointDir: String): DataFrame = {
+    spark.read
+      .format("statestore")
+      .option(StateSourceOptions.PATH, checkpointDir)
+      .load()
+      .selectExpr("partition_id", "key", "value")
+  }
+
+  /**
+   * Returns a DataFrame with raw bytes mode 
(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES = true).
+   */
+  private def getBytesReadDf(checkpointDir: String): DataFrame = {
+    spark.read
+      .format("statestore")
+      .option(StateSourceOptions.PATH, checkpointDir)
+      .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, 
"true")
+      .load()
+  }
+
+  /**
+   * Validates the schema and column families of the bytes read DataFrame.
+   */
+  private def validateBytesReadSchema(df: DataFrame): Unit = {
+    // Verify schema
+    val schema = df.schema
+    assert(schema.fieldNames === Array(
+      "partition_key", "key_bytes", "value_bytes", "column_family_name", 
"value", "key"))
+    assert(schema("partition_key").dataType.typeName === "struct")
+    assert(schema("key_bytes").dataType.typeName === "binary")
+    assert(schema("value_bytes").dataType.typeName === "binary")
+    assert(schema("column_family_name").dataType.typeName === "string")
+  }
+
+  private def parseBytesReadData(df: Array[Row], keyLength: Int, valueLength: 
Int)
+    : Set[(GenericRowWithSchema, UnsafeRow, UnsafeRow, String)] = {
+    df.map { row =>
+        val partitionKey = row.getAs[GenericRowWithSchema](0)
+        val keyBytes = row.getAs[Array[Byte]](1)
+        val valueBytes = row.getAs[Array[Byte]](2)
+        val columnFamily = row.getString(3)
+
+        // Deserialize key bytes to UnsafeRow
+        val keyRow = new UnsafeRow(keyLength)
+        keyRow.pointTo(
+          keyBytes,
+          Platform.BYTE_ARRAY_OFFSET,
+          keyBytes.length)
+
+        // Deserialize value bytes to UnsafeRow
+        val valueRow = new UnsafeRow(valueLength)
+        valueRow.pointTo(
+          valueBytes,
+          Platform.BYTE_ARRAY_OFFSET,
+          valueBytes.length)
+        (partitionKey, keyRow.copy(), valueRow.copy(), columnFamily)
+      }
+      .toSet
+  }
+
+  /**
+   * Compares normal read data with bytes read data for a specific column 
family.
+   */
+  private def compareNormalAndBytesData(
+      normalReadDf: DataFrame,
+      bytesReadDf: DataFrame,
+      columnFamily: String,
+      keySchema: StructType,
+      valueSchema: StructType): Unit = {
+    // Verify data
+    val bytesDf = bytesReadDf
+      .selectExpr("partition_key", "key_bytes", "value_bytes", 
"column_family_name")
+      .collect()
+    assert(bytesDf.length == 10,
+      s"Expected 10 rows but got: ${bytesDf.length}")
+
+    // Filter bytes data for the specified column family
+    val bytesData = parseBytesReadData(bytesDf, keySchema.length, 
valueSchema.length)
+    val filteredBytesData = bytesData.filter(_._4 == columnFamily)
+
+    // Apply the projection
+    // Convert to comparable format (extract field values)
+    val normalSet = normalReadDf.collect().map { row =>
+      val key = row.getStruct(1)
+      val value = row.getStruct(2)
+      val keyFields = (0 until key.length).map(i => key.get(i))
+      val valueFields = (0 until value.length).map(i => value.get(i))
+      (keyFields, valueFields)
+    }.toSet
+
+    val bytesSet = filteredBytesData.map { case (_, keyRow, valueRow, _) =>
+      val keyFields = (0 until keySchema.length).map(i =>
+        keyRow.get(i, keySchema(i).dataType))
+      val valueFields = (0 until valueSchema.length).map(i =>
+        valueRow.get(i, valueSchema(i).dataType))
+      (keyFields, valueFields)
+    }
+    // Verify same number of rows
+    assert(filteredBytesData.size == normalSet.size,

Review Comment:
   are you comparing the correct vals here?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala:
##########
@@ -60,6 +61,17 @@ object SchemaUtil {
         .add("key", keySchema)
         .add("value", valueSchema)
         .add("partition_id", IntegerType)
+    } else if (sourceOptions.internalOnlyReadAllColumnFamilies) {
+      new StructType()
+        // todo: change this to some more specific type after we

Review Comment:
   Add this ticket number `SPARK-54443` to the todo
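   e.g. a sketch of the updated comment:

   ```scala
   // TODO(SPARK-54443): change this to a more specific type once the partition key
   //  can be extracted from keySchema
   ```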



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

