liviazhu commented on code in PR #56018:
URL: https://github.com/apache/spark/pull/56018#discussion_r3321294475


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala:
##########
@@ -294,7 +294,10 @@ class OfflineStateRepartitionRunner(
       lastCommittedBatchId: Long,
       opIdToStateStoreCkptInfo: Option[Map[Long, Array[Array[String]]]]): Unit 
= {
     val latestCommit = 
checkpointMetadata.commitLog.get(lastCommittedBatchId).get
-    val commitMetadata = latestCommit.copy(stateUniqueIds = 
opIdToStateStoreCkptInfo)
+    val commitMetadata = checkpointMetadata.commitLog.createMetadata(

Review Comment:
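   Quick question: why are we refactoring this to use createMetadata rather than .copy()? 
Unlike .copy(), createMetadata builds the metadata from scratch instead of starting from latestCommit, so we may silently drop new optional fields when new metadata versions are introduced.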



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/CommitLog.scala:
##########
@@ -50,39 +51,99 @@ class CommitLog(
     sparkSession: SparkSession,
     path: String,
     readOnly: Boolean = false)
-  extends HDFSMetadataLog[CommitMetadata](sparkSession, path, readOnly) {
+  extends HDFSMetadataLog[CommitMetadataBase](sparkSession, path, readOnly) {
 
   import CommitLog._
 
-  private val VERSION: Int = sparkSession.conf.get(
+  // The configured commit log format version. Used as the default version 
when callers
+  // construct metadata through [[createMetadata]].
+  private[sql] val VERSION: Int = sparkSession.conf.get(
     SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key).toInt
 
-  override protected[sql] def deserialize(in: InputStream): CommitMetadata = {
-    // called inside a try-finally where the underlying stream is closed in 
the caller
-    val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines()
-    if (!lines.hasNext) {
-      throw new IllegalStateException("Incomplete log file in the offset 
commit log")
-    }
-    // TODO [SPARK-49462] This validation should be relaxed for a stateless 
query.
-    // TODO [SPARK-50653] This validation should be relaxed to support reading
-    //  a V1 log file when VERSION is V2
-    validateVersionExactMatch(lines.next().trim, VERSION)
-    val metadataJson = if (lines.hasNext) lines.next() else EMPTY_JSON
-    CommitMetadata(metadataJson)
+  override protected[sql] def deserialize(in: InputStream): CommitMetadataBase 
= {
+    CommitLog.readCommitMetadata(in)
   }
 
-  override protected[sql] def serialize(metadata: CommitMetadata, out: 
OutputStream): Unit = {
+  override protected[sql] def serialize(metadata: CommitMetadataBase, out: 
OutputStream): Unit = {
     // called inside a try-finally where the underlying stream is closed in 
the caller
-    out.write(s"v${VERSION}".getBytes(UTF_8))
+    out.write(s"v${metadata.version}".getBytes(UTF_8))
     out.write('\n')
 
     // write metadata
     out.write(metadata.json.getBytes(UTF_8))
   }
+
+  /**
+   * Factory for creating a [[CommitMetadataBase]] for the requested wire 
format version.
+   * Defaults to the version configured via 
[[SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION]].
+   */
+  def createMetadata(
+      nextBatchWatermarkMs: Long = 0,
+      stateUniqueIds: Option[Map[Long, Array[Array[String]]]] = None,
+      commitLogFormatVersion: Int = VERSION): CommitMetadataBase = {
+    commitLogFormatVersion match {
+      case VERSION_2 =>
+        CommitMetadataV2(nextBatchWatermarkMs, stateUniqueIds)
+      case VERSION_1 =>
+        CommitMetadata(nextBatchWatermarkMs)

Review Comment:
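   Should we throw an error, or at least log a warning, if stateUniqueIds is passed in but the version is 1? Currently the V1 branch constructs CommitMetadata(nextBatchWatermarkMs) and silently ignores stateUniqueIds.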



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/CommitLog.scala:
##########
@@ -50,39 +51,99 @@ class CommitLog(
     sparkSession: SparkSession,
     path: String,
     readOnly: Boolean = false)
-  extends HDFSMetadataLog[CommitMetadata](sparkSession, path, readOnly) {
+  extends HDFSMetadataLog[CommitMetadataBase](sparkSession, path, readOnly) {
 
   import CommitLog._
 
-  private val VERSION: Int = sparkSession.conf.get(
+  // The configured commit log format version. Used as the default version 
when callers
+  // construct metadata through [[createMetadata]].
+  private[sql] val VERSION: Int = sparkSession.conf.get(

Review Comment:
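   Can we rename VERSION to defaultVersion? It is no longer a static format version: it is read from the session conf and only used as the default version for createMetadata.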



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala:
##########
@@ -600,37 +598,6 @@ class RocksDBWithCheckpointV2StateDataSourceReaderSuite 
extends StateDataSourceR
       "true")
   }
 
-  // TODO: Remove this test once we allow migrations from checkpoint v1 to v2
-  test("reading checkpoint v2 store with version 1 should fail") {

Review Comment:
   Should we flip this test to expect success rather than removing it?



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala:
##########
@@ -600,37 +598,6 @@ class RocksDBWithCheckpointV2StateDataSourceReaderSuite 
extends StateDataSourceR
       "true")
   }
 
-  // TODO: Remove this test once we allow migrations from checkpoint v1 to v2
-  test("reading checkpoint v2 store with version 1 should fail") {

Review Comment:
   Also add a test for the reverse case (reading v2 metadata when v1 is configured).



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/CommitLog.scala:
##########
@@ -104,19 +165,19 @@ object CommitLog {
  *          +--- ......
  * In the commit log, in addition to nextBatchWatermarkMs, we also store the 
unique ids of the
  * state store files.
+ *
  * @param nextBatchWatermarkMs The watermark of the next batch.
  * @param stateUniqueIds Map[Long, Array[Array[String]]] of map
  *                       OperatorId -> (partitionID -> array of uniqueID)
  */
-
-case class CommitMetadata(
+case class CommitMetadataV2(
     nextBatchWatermarkMs: Long = 0,
-    stateUniqueIds: Option[Map[Long, Array[Array[String]]]] = None) {
-  def json: String = Serialization.write(this)(CommitMetadata.format)
+    stateUniqueIds: Option[Map[Long, Array[Array[String]]]] = None) extends 
CommitMetadataBase {
+  override def version: Int = CommitLog.VERSION_2
 }
 
-object CommitMetadata {
+object CommitMetadataV2 {
   implicit val format: Formats = Serialization.formats(NoTypeHints)

Review Comment:
   Should we remove this and just use CommitMetadata.format instead?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/CommitLog.scala:
##########
@@ -50,39 +51,99 @@ class CommitLog(
     sparkSession: SparkSession,
     path: String,
     readOnly: Boolean = false)
-  extends HDFSMetadataLog[CommitMetadata](sparkSession, path, readOnly) {
+  extends HDFSMetadataLog[CommitMetadataBase](sparkSession, path, readOnly) {
 
   import CommitLog._
 
-  private val VERSION: Int = sparkSession.conf.get(
+  // The configured commit log format version. Used as the default version 
when callers
+  // construct metadata through [[createMetadata]].
+  private[sql] val VERSION: Int = sparkSession.conf.get(
     SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key).toInt
 
-  override protected[sql] def deserialize(in: InputStream): CommitMetadata = {
-    // called inside a try-finally where the underlying stream is closed in 
the caller
-    val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines()
-    if (!lines.hasNext) {
-      throw new IllegalStateException("Incomplete log file in the offset 
commit log")
-    }
-    // TODO [SPARK-49462] This validation should be relaxed for a stateless 
query.
-    // TODO [SPARK-50653] This validation should be relaxed to support reading
-    //  a V1 log file when VERSION is V2
-    validateVersionExactMatch(lines.next().trim, VERSION)
-    val metadataJson = if (lines.hasNext) lines.next() else EMPTY_JSON
-    CommitMetadata(metadataJson)
+  override protected[sql] def deserialize(in: InputStream): CommitMetadataBase 
= {
+    CommitLog.readCommitMetadata(in)
   }
 
-  override protected[sql] def serialize(metadata: CommitMetadata, out: 
OutputStream): Unit = {
+  override protected[sql] def serialize(metadata: CommitMetadataBase, out: 
OutputStream): Unit = {
     // called inside a try-finally where the underlying stream is closed in 
the caller
-    out.write(s"v${VERSION}".getBytes(UTF_8))
+    out.write(s"v${metadata.version}".getBytes(UTF_8))
     out.write('\n')
 
     // write metadata
     out.write(metadata.json.getBytes(UTF_8))
   }
+
+  /**
+   * Factory for creating a [[CommitMetadataBase]] for the requested wire 
format version.
+   * Defaults to the version configured via 
[[SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION]].
+   */
+  def createMetadata(
+      nextBatchWatermarkMs: Long = 0,
+      stateUniqueIds: Option[Map[Long, Array[Array[String]]]] = None,
+      commitLogFormatVersion: Int = VERSION): CommitMetadataBase = {
+    commitLogFormatVersion match {
+      case VERSION_2 =>
+        CommitMetadataV2(nextBatchWatermarkMs, stateUniqueIds)
+      case VERSION_1 =>
+        CommitMetadata(nextBatchWatermarkMs)

Review Comment:
   Also add a test for this V1 branch of createMetadata (e.g. stateUniqueIds passed in when the version is 1).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to