zifeif2 commented on code in PR #53104: URL: https://github.com/apache/spark/pull/53104#discussion_r2557243370
########## sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala: ########## @@ -0,0 +1,449 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.state + +import java.nio.ByteOrder +import java.util.Arrays + +import org.apache.spark.SparkUnsupportedOperationException +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.execution.streaming.runtime.MemoryStream +import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider} +import org.apache.spark.sql.functions.{count, sum} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.{IntegerType, LongType, NullType, StructField, StructType} + +/** + * Note: This extends StateDataSourceTestBase to access + * helper methods like runDropDuplicatesQuery without inheriting all predefined tests. + */ +class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase { + + import testImplicits._ + + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key, + classOf[RocksDBStateStoreProvider].getName) + } + + private def getNormalReadDf(checkpointDir: String): DataFrame = { + spark.read + .format("statestore") + .option(StateSourceOptions.PATH, checkpointDir) + .load() + .selectExpr("partition_id", "key", "value") + } + + private def getBytesReadDf(checkpointDir: String): DataFrame = { + spark.read + .format("statestore") + .option(StateSourceOptions.PATH, checkpointDir) + .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, "true") + .load() + .selectExpr("partition_key", "key_bytes", "value_bytes", "column_family_name") + } + + /** + * Validates the schema and column families of the bytes read DataFrame. + */ + private def validateBytesReadDfSchema(df: DataFrame): Unit = { + // Verify schema + val schema = df.schema + assert(schema.fieldNames === Array( + "partition_key", "key_bytes", "value_bytes", "column_family_name")) + assert(schema("partition_key").dataType.typeName === "struct") + assert(schema("key_bytes").dataType.typeName === "binary") + assert(schema("value_bytes").dataType.typeName === "binary") + assert(schema("column_family_name").dataType.typeName === "string") + } + + /** + * Compares normal read data with bytes read data for a specific column family. + * Converts normal rows to bytes then compares with bytes read. + */ + private def compareNormalAndBytesData( + normalDf: Array[Row], + bytesDf: Array[Row], + columnFamily: String, + keySchema: StructType, + valueSchema: StructType): Unit = { + + // Filter bytes data for the specified column family and extract raw bytes directly + val filteredBytesData = bytesDf.filter { row => + row.getString(3) == columnFamily + } + + // Verify same number of rows + assert(filteredBytesData.length == normalDf.length, + s"Row count mismatch for column family '$columnFamily': " + + s"normal read has ${normalDf.length} rows, " + + s"bytes read has ${filteredBytesData.length} rows") + + // Create projections to convert Row to UnsafeRow bytes + val keyProjection = UnsafeProjection.create(keySchema) + val valueProjection = UnsafeProjection.create(valueSchema) + + // Create converters to convert external Row types to internal Catalyst types + val keyConverter = CatalystTypeConverters.createToCatalystConverter(keySchema) + val valueConverter = CatalystTypeConverters.createToCatalystConverter(valueSchema) + + // Convert normal data to bytes + val normalAsBytes = normalDf.map { row => + val key = row.getStruct(1) + val value = if (row.isNullAt(2)) null else row.getStruct(2) + + // Convert key to InternalRow, then to UnsafeRow, then get bytes + val keyInternalRow = keyConverter(key).asInstanceOf[InternalRow] + val keyUnsafeRow = keyProjection(keyInternalRow) + // IMPORTANT: Must clone the bytes array since getBytes() returns a reference + // that may be overwritten by subsequent UnsafeRow operations + val keyBytes = keyUnsafeRow.getBytes.clone() + + // Convert value to bytes + val valueBytes = if (value == null) { + Array.empty[Byte] + } else { + val valueInternalRow = valueConverter(value).asInstanceOf[InternalRow] + val valueUnsafeRow = valueProjection(valueInternalRow) + // IMPORTANT: Must clone the bytes array since getBytes() returns a reference + // that may be overwritten by subsequent UnsafeRow operations + valueUnsafeRow.getBytes.clone() + } + + (keyBytes, valueBytes) + } + + // Extract raw bytes from bytes read data (no deserialization/reserialization) + val bytesAsBytes = filteredBytesData.map { row => + val keyBytes = row.getAs[Array[Byte]](1) + val valueBytes = row.getAs[Array[Byte]](2) + (keyBytes, valueBytes) + } + + // Sort both for comparison (since Set equality doesn't work well with byte arrays) + val normalSorted = normalAsBytes.sortBy(x => (x._1.mkString(","), x._2.mkString(","))) + val bytesSorted = bytesAsBytes.sortBy(x => (x._1.mkString(","), x._2.mkString(","))) + + assert(normalSorted.length == bytesSorted.length, + s"Size mismatch: normal has ${normalSorted.length}, bytes has ${bytesSorted.length}") + + // Compare each pair + normalSorted.zip(bytesSorted).zipWithIndex.foreach { + case (((normalKey, normalValue), (bytesKey, bytesValue)), idx) => + assert(Arrays.equals(normalKey, bytesKey), + s"Key mismatch at index $idx:\n" + + s" Normal: ${normalKey.mkString("[", ",", "]")}\n" + + s" Bytes: ${bytesKey.mkString("[", ",", "]")}") + assert(Arrays.equals(normalValue, bytesValue), + s"Value mismatch at index $idx:\n" + + s" Normal: ${normalValue.mkString("[", ",", "]")}\n" + + s" Bytes: ${bytesValue.mkString("[", ",", "]")}") + } + } + + // Run all tests with both changelog checkpointing enabled and disabled + Seq(true, false).foreach { changelogCheckpointingEnabled => + val testSuffix = if (changelogCheckpointingEnabled) { + "with changelog checkpointing" + } else { + "without changelog checkpointing" + } + + def testWithChangelogConfig(testName: String)(testFun: => Unit): Unit = { + test(s"$testName ($testSuffix)") { + withSQLConf( + "spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled" -> + changelogCheckpointingEnabled.toString) { + testFun + } + } + } + + testWithChangelogConfig("all-column-families: simple aggregation state ver 1") { + withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "1") { + withTempDir { tempDir => + runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath) + + val keySchema = StructType(Array(StructField("groupKey", IntegerType, nullable = false))) + // State version 1 includes key columns in the value + val valueSchema = StructType(Array( + StructField("groupKey", IntegerType, nullable = false), + StructField("count", LongType, nullable = false), + StructField("sum", LongType, nullable = false), + StructField("max", IntegerType, nullable = false), + StructField("min", IntegerType, nullable = false) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath) + + validateBytesReadDfSchema(bytesDf) + compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema) + } + } + } + + testWithChangelogConfig("all-column-families: simple aggregation state ver 2") { + withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "2") { + withTempDir { tempDir => + runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath) + + val keySchema = StructType(Array(StructField("groupKey", IntegerType, nullable = false))) + val valueSchema = StructType(Array( + StructField("count", LongType, nullable = false), + StructField("sum", LongType, nullable = false), + StructField("max", IntegerType, nullable = false), + StructField("min", IntegerType, nullable = false) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect() + + compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema) + } + } + } + + testWithChangelogConfig("all-column-families: composite key aggregation state ver 1") { + withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "1") { + withTempDir { tempDir => + runCompositeKeyStreamingAggregationQuery(tempDir.getAbsolutePath) + + val keySchema = StructType(Array( + StructField("groupKey", IntegerType, nullable = false), + StructField("fruit", org.apache.spark.sql.types.StringType, nullable = true) + )) + // State version 1 includes key columns in the value + val valueSchema = StructType(Array( + StructField("groupKey", IntegerType, nullable = false), + StructField("fruit", org.apache.spark.sql.types.StringType, nullable = true), + StructField("count", LongType, nullable = false), + StructField("sum", LongType, nullable = false), + StructField("max", IntegerType, nullable = false), + StructField("min", IntegerType, nullable = false) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect() + + compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema) + } + } + } + + testWithChangelogConfig("all-column-families: composite key aggregation state ver 2") { + withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "2") { + withTempDir { tempDir => + runCompositeKeyStreamingAggregationQuery(tempDir.getAbsolutePath) + + val keySchema = StructType(Array( + StructField("groupKey", IntegerType, nullable = false), + StructField("fruit", org.apache.spark.sql.types.StringType, nullable = true) + )) + val valueSchema = StructType(Array( + StructField("count", LongType, nullable = false), + StructField("sum", LongType, nullable = false), + StructField("max", IntegerType, nullable = false), + StructField("min", IntegerType, nullable = false) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect() + + compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema) + } + } + } + + testWithChangelogConfig("all-column-families: dropDuplicates validation") { + withTempDir { tempDir => + runDropDuplicatesQuery(tempDir.getAbsolutePath) Review Comment: Will fix indentation! -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
