zifeif2 commented on code in PR #53720:
URL: https://github.com/apache/spark/pull/53720#discussion_r2684474214
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala:
##########
@@ -312,8 +323,21 @@ object OfflineStateRepartitionTestUtils {
verifyOffsetAndCommitLog(
batchId, previousBatchId, expectedShufflePartitions, checkpointMetadata)
verifyPartitionDirs(checkpointLocation, expectedShufflePartitions)
+
+ val serializableConf = new SerializableConfiguration(hadoopConf)
+ val baseOperatorsMetadata = getOperatorMetadata(
+ checkpointLocation, serializableConf, previousBatchId)
+ val repartitionOperatorsMetadata = getOperatorMetadata(
Review Comment:
Move the construction of operatorsMetadata out of verifyOperatorMetadata,
because we need the number of state stores to pass to verifyCheckpointIds.
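Roughly what I have in mind, as a sketch (the exact arguments of verifyCheckpointIds are an assumption here):
```scala
// Build the operator metadata once at the call site, for both batches.
val serializableConf = new SerializableConfiguration(hadoopConf)
val baseOperatorsMetadata = getOperatorMetadata(
  checkpointLocation, serializableConf, previousBatchId)
val repartitionOperatorsMetadata = getOperatorMetadata(
  checkpointLocation, serializableConf, batchId)

// verifyOperatorMetadata no longer constructs the metadata internally;
// it just compares what it is given.
verifyOperatorMetadata(baseOperatorsMetadata, repartitionOperatorsMetadata)

// verifyCheckpointIds can derive the number of state stores per operator
// from the same metadata instead of recomputing it.
verifyCheckpointIds(repartitionOperatorsMetadata, expectedShufflePartitions)
```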
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala:
##########
@@ -81,24 +82,46 @@ class StateRewriter(
private val stateRootLocation = new Path(
resolvedCheckpointLocation,
StreamingCheckpointConstants.DIR_NAME_STATE).toString
- def run(): Unit = {
+  // return a Map[operator id, Array[partition -> Array[stateStore -> StateStoreCheckpointInfo]]]
+ def run(): Map[Long, Array[Array[StateStoreCheckpointInfo]]] = {
    logInfo(log"Starting state rewrite for " +
      log"checkpointLocation=${MDC(CHECKPOINT_LOCATION, resolvedCheckpointLocation)}, " +
      log"readCheckpointLocation=" +
      log"${MDC(CHECKPOINT_LOCATION, readResolvedCheckpointLocation.getOrElse(""))}, " +
      log"readBatchId=${MDC(BATCH_ID, readBatchId)}, " +
      log"writeBatchId=${MDC(BATCH_ID, writeBatchId)}")
- val (_, timeTakenMs) = Utils.timeTakenMs {
+ val (checkpointInfos, timeTakenMs) = Utils.timeTakenMs {
runInternal()
}
+    val allCheckpointInfos = checkpointInfos.values.flatMap { checkpointInfos =>
+      checkpointInfos.flatten.toSeq
Review Comment:
I tried printing out checkpointInfos as is, but it prints something like:
```
Map(0 -> [[Lorg.apache.spark.sql.execution.streaming.state.StateStoreCheckpointInfo;@28d6bdea)
```
while if we flatten it first, the print shows the detailed info:
```
List(StateStoreCheckpointInfo(0,3,Some(3b36fc9a-128e-4aae-a9ea-dc743ba87c98),None),
  StateStoreCheckpointInfo(1,3,Some(2537e690-860a-4662-a526-3ea7346a01b4),None))
```
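The underlying reason: JVM arrays inherit toString from java.lang.Object, so nested arrays print as a reference like `[[L...;@<hash>`. A standalone repro, with String ids standing in for the checkpoint infos:
```scala
object ArrayToStringRepro extends App {
  // Same shape as the run() result, with String ids standing in for
  // StateStoreCheckpointInfo: operator id -> partitions -> stores.
  val m: Map[Long, Array[Array[String]]] =
    Map(0L -> Array(Array("ckpt-a"), Array("ckpt-b")))

  // Arrays don't override toString, so this prints the reference:
  // Map(0 -> [[Ljava.lang.String;@<hash>)
  println(m)

  // Flattening into a List makes the contents visible:
  // Map(0 -> List(ckpt-a, ckpt-b))
  println(m.map { case (k, v) => k -> v.flatten.toList })

  // Or keep the per-partition nesting while still printing the contents:
  // Map(0 -> List(List(ckpt-a), List(ckpt-b)))
  println(m.map { case (k, v) => k -> v.map(_.toList).toList })
}
```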
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala:
##########
@@ -289,12 +290,27 @@ class OfflineStateRepartitionRunner(
}
}
-  private def commitBatch(newBatchId: Long, lastCommittedBatchId: Long): Unit = {
+ private def commitBatch(
+ newBatchId: Long,
+ lastCommittedBatchId: Long,
+      opIdToStateStoreCkptInfo: Map[Long, Array[Array[StateStoreCheckpointInfo]]]): Unit = {
+ val enableCheckpointId = StatefulOperatorStateInfo.
+ enableStateStoreCheckpointIds(sparkSession.sessionState.conf)
val latestCommit =
checkpointMetadata.commitLog.get(lastCommittedBatchId).get
+ val commitMetadata = if (enableCheckpointId) {
+ val opIdToPartitionCkptInfo: Map[Long, Array[Array[String]]] =
+ opIdToStateStoreCkptInfo.map {
+ case (operator, partitionSeq) =>
+ operator -> partitionSeq.map(storeSeq =>
+ storeSeq.flatMap(info => info.stateStoreCkptId))
Review Comment:
I see. I was thinking we might need other fields in StateCheckpointInfo from
StateRewriter in future so I kept StateRewriter returning the entire
StateStoreCkptInfo instead of just the checkpoint ids. Down to just let
StateRewriter return StateStoreCkptId though!
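For reference, a sketch of the two shapes side by side (the case class is a cut-down stand-in for the real StateStoreCheckpointInfo, and the field names other than stateStoreCkptId are assumptions):
```scala
// Cut-down stand-in matching the four fields seen in the printed output.
case class StateStoreCheckpointInfo(
    partitionId: Int,
    batchVersion: Long,
    stateStoreCkptId: Option[String],
    baseStateStoreCkptId: Option[String])

// What StateRewriter returns today: operator id -> partitions -> stores,
// carrying the full info (ids truncated here for brevity).
val full: Map[Long, Array[Array[StateStoreCheckpointInfo]]] = Map(
  0L -> Array(
    Array(StateStoreCheckpointInfo(0, 3, Some("3b36fc9a"), None)),
    Array(StateStoreCheckpointInfo(1, 3, Some("2537e690"), None))))

// What it would return if narrowed to just the ids; the mapping that
// commitBatch does today would move into StateRewriter.
val idsOnly: Map[Long, Array[Array[String]]] = full.map {
  case (operator, partitionSeq) =>
    operator -> partitionSeq.map(_.flatMap(_.stateStoreCkptId))
}
```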
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]