micheal-o commented on code in PR #53703:
URL: https://github.com/apache/spark/pull/53703#discussion_r2666825514


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala:
##########
@@ -51,136 +48,90 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
     spark.conf.set(SQLConf.SHUFFLE_PARTITIONS.key, "2")
   }
 
-  /**
-   * Helper method to create a StateSchemaProvider from column family schema 
map.
-   */
-  private def createStateSchemaProvider(
-      columnFamilyToSchemaMap: Map[String, 
StatePartitionWriterColumnFamilyInfo]
-  ): StateSchemaProvider = {
-    val testSchemaProvider = new TestStateSchemaProvider()
-    columnFamilyToSchemaMap.foreach { case (cfName, cfInfo) =>
-      testSchemaProvider.captureSchema(
-        colFamilyName = cfName,
-        keySchema = cfInfo.schema.keySchema,
-        valueSchema = cfInfo.schema.valueSchema,
-        keySchemaId = cfInfo.schema.keySchemaId,
-        valueSchemaId = cfInfo.schema.valueSchemaId
-      )
-    }
-    testSchemaProvider
-  }
-
   /**
    * Common helper method to perform round-trip test: read state bytes from 
source,
    * write to target, and verify target matches source.
    *
    * @param sourceDir Source checkpoint directory
    * @param targetDir Target checkpoint directory
-   * @param columnFamilyToSchemaMap Map of column family names to their schemas
-   * @param storeName Optional store name (for stream-stream join which has 
multiple stores)
-   * @param columnFamilyToSelectExprs Map of column family names to custom 
selectExprs
-   * @param columnFamilyToStateSourceOptions Map of column family names to 
state source options
+   * @param storeToColumnFamilies Map of store name to its column families
+   * @param storeToColumnFamilyToSelectExprs Map store name to per column 
family custom selectExprs
+   * @param storeToColumnFamilyToStateSourceOptions Map store name to per 
column family
+   *                                                state source options
    */
   private def performRoundTripTest(
       sourceDir: String,
       targetDir: String,
-      columnFamilyToSchemaMap: Map[String, 
StatePartitionWriterColumnFamilyInfo],
-      storeName: Option[String] = None,
-      columnFamilyToSelectExprs: Map[String, Seq[String]] = Map.empty,
-      columnFamilyToStateSourceOptions: Map[String, Map[String, String]] = 
Map.empty,
+      storeToColumnFamilies: Map[String, List[String]] =
+        Map(StateStoreId.DEFAULT_STORE_NAME -> 
List(StateStore.DEFAULT_COL_FAMILY_NAME)),
+      storeToColumnFamilyToSelectExprs: Map[String, Map[String, Seq[String]]] 
= Map.empty,
+      storeToColumnFamilyToStateSourceOptions: Map[String, Map[String, 
Map[String, String]]] =
+        Map.empty,
       operatorName: String): Unit = {
-
-    val columnFamiliesToValidate: Seq[String] = if 
(columnFamilyToSchemaMap.size > 1) {
-      columnFamilyToSchemaMap.keys.toSeq
-    } else {
-      Seq(StateStore.DEFAULT_COL_FAMILY_NAME)
-    }
-
-    // Step 1: Read from source using AllColumnFamiliesReader (raw bytes)
-    val sourceBytesReader = spark.read
-      .format("statestore")
-      .option(StateSourceOptions.PATH, sourceDir)
-      .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, 
"true")
-    val sourceBytesData = (storeName match {
-      case Some(name) => 
sourceBytesReader.option(StateSourceOptions.STORE_NAME, name)
-      case None => sourceBytesReader
-    }).load()
-
-    // Verify schema of raw bytes
-    val schema = sourceBytesData.schema
-    assert(schema.fieldNames === Array(
-      "partition_key", "key_bytes", "value_bytes", "column_family_name"))
-
-    // Step 2: Write raw bytes to target checkpoint location
     val hadoopConf = spark.sessionState.newHadoopConf()
+    val sourceCpLocation = StreamingUtils.resolvedCheckpointLocation(
+      hadoopConf, sourceDir)
+    val sourceCheckpointMetadata = new StreamingQueryCheckpointMetadata(
+      spark, sourceCpLocation)
+    val readBatchId = sourceCheckpointMetadata.commitLog.getLatestBatchId().get
+
     val targetCpLocation = StreamingUtils.resolvedCheckpointLocation(
       hadoopConf, targetDir)
     val targetCheckpointMetadata = new StreamingQueryCheckpointMetadata(
       spark, targetCpLocation)
     // increase offsetCheckpoint
     val lastBatch = targetCheckpointMetadata.commitLog.getLatestBatchId().get
     val targetOffsetSeq = targetCheckpointMetadata.offsetLog.get(lastBatch).get
-    val currentBatchId = lastBatch + 1
-    targetCheckpointMetadata.offsetLog.add(currentBatchId, targetOffsetSeq)
-
-    val storeConf: StateStoreConf = StateStoreConf(spark.sessionState.conf)
-    val serializableHadoopConf = new SerializableConfiguration(hadoopConf)
-
-    // Create StateSchemaProvider if needed (for Avro encoding)
-    val stateSchemaProvider = if (storeConf.stateStoreEncodingFormat == 
"avro") {
-      Some(createStateSchemaProvider(columnFamilyToSchemaMap))
-    } else {
-      None
-    }
-    val baseConfs: Map[String, String] = spark.sessionState.conf.getAllConfs
-    val putPartitionFunc: Iterator[InternalRow] => Unit = partition => {
-      val newConf = new SQLConf
-      baseConfs.foreach { case (k, v) =>
-        newConf.setConfString(k, v)
-      }
-      val allCFWriter = new StatePartitionAllColumnFamiliesWriter(
-        storeConf,
-        serializableHadoopConf.value,
-        TaskContext.getPartitionId(),
-        targetCpLocation,
-        0,
-        storeName.getOrElse(StateStoreId.DEFAULT_STORE_NAME),
-        currentBatchId,
-        columnFamilyToSchemaMap,
-        operatorName,
-        stateSchemaProvider,
-        newConf
-      )
-      allCFWriter.write(partition)
-    }
-    sourceBytesData.queryExecution.toRdd.foreachPartition(putPartitionFunc)
+    val writeBatchId = lastBatch + 1
+    targetCheckpointMetadata.offsetLog.add(writeBatchId, targetOffsetSeq)
+
+    val rewriter = new StateRewriter(
+      spark,
+      readBatchId,
+      writeBatchId,
+      targetCpLocation,
+      hadoopConf,
+      readResolvedCheckpointLocation = Some(sourceCpLocation),
+      transformFunc = None,

Review Comment:
   That will be added in my next PR, when I integrate with the repartition runner.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to