gardnervickers commented on a change in pull request #7929:
URL: https://github.com/apache/kafka/pull/7929#discussion_r501942512



##########
File path: core/src/main/scala/kafka/log/ProducerStateManager.scala
##########
@@ -653,36 +697,44 @@ class ProducerStateManager(val topicPartition: TopicPartition,
   def takeSnapshot(): Unit = {
     // If not a new offset, then it is not worth taking another snapshot
     if (lastMapOffset > lastSnapOffset) {
-      val snapshotFile = Log.producerSnapshotFile(logDir, lastMapOffset)
+      val snapshotFile = SnapshotFile(Log.producerSnapshotFile(_logDir, lastMapOffset))
       info(s"Writing producer snapshot at offset $lastMapOffset")
-      writeSnapshot(snapshotFile, producers)
+      writeSnapshot(snapshotFile.file, producers)
+      snapshots.put(snapshotFile.offset, snapshotFile)
 
       // Update the last snap offset according to the serialized map
       lastSnapOffset = lastMapOffset
     }
   }
 
+  /**
+   * Update the parentDir for this ProducerStateManager and all of the snapshot files which it manages.
+   */
+  def updateParentDir(parentDir: File): Unit = {
+    _logDir = parentDir
+    snapshots.forEach((_, s) => s.updateParentDir(parentDir))
+  }
+
   /**
    * Get the last offset (exclusive) of the latest snapshot file.
    */
-  def latestSnapshotOffset: Option[Long] = latestSnapshotFile.map(file => offsetFromFile(file))
+  def latestSnapshotOffset: Option[Long] = latestSnapshotFile.map(_.offset)
 
   /**
    * Get the last offset (exclusive) of the oldest snapshot file.
    */
-  def oldestSnapshotOffset: Option[Long] = oldestSnapshotFile.map(file => offsetFromFile(file))
+  def oldestSnapshotOffset: Option[Long] = oldestSnapshotFile.map(_.offset)
 
   /**
-   * When we remove the head of the log due to retention, we need to remove snapshots older than
-   * the new log start offset.
+   * Remove any unreplicated transactions lower than the provided logStartOffset and bring the lastMapOffset forward
+   * if necessary.
    */
-  def truncateHead(logStartOffset: Long): Unit = {
+  def onLogStartOffsetIncremented(logStartOffset: Long): Unit = {
     removeUnreplicatedTransactions(logStartOffset)
 
     if (lastMapOffset < logStartOffset)
       lastMapOffset = logStartOffset
 
-    deleteSnapshotsBefore(logStartOffset)

Review comment:
       The idea here is to clear unreplicated transactions and optionally advance the `lastMapOffset` and `lastSnapOffset` when the logStartOffset is advanced, but to leave the snapshot files around. The corresponding snapshot files should be removed during the retention pass as we clean up the associated segment files.
   
   I was attempting to optimize incrementing the logStartOffset a bit so that we don't need to delete the snapshot files from the request handler thread when handling `DELETE_RECORDS`.
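   
   As a rough sketch of how that division of work could look (illustrative only; it assumes the `snapshots` map and `SnapshotFile` type from this PR, and `removeAndDeleteSnapshot` is just an assumed name for the retention-side helper):

```scala
// Request handler thread on DELETE_RECORDS: bookkeeping only, no file deletion.
def onLogStartOffsetIncremented(logStartOffset: Long): Unit = {
  removeUnreplicatedTransactions(logStartOffset)
  if (lastMapOffset < logStartOffset)
    lastMapOffset = logStartOffset
  if (lastSnapOffset < logStartOffset)
    lastSnapOffset = logStartOffset
}

// Retention pass, as each segment is deleted: drop the snapshot with the
// matching offset, if one exists.
def removeAndDeleteSnapshot(segmentBaseOffset: Long): Unit =
  Option(snapshots.remove(segmentBaseOffset)).foreach(_.deleteIfExists())
```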

##########
File path: core/src/main/scala/kafka/log/ProducerStateManager.scala
##########
@@ -496,6 +491,53 @@ class ProducerStateManager(val topicPartition: TopicPartition,
   // completed transactions whose markers are at offsets above the high watermark
   private val unreplicatedTxns = new util.TreeMap[Long, TxnMetadata]
 
+  /**
+   * Load producer state snapshots by scanning the _logDir.
+   */
+  private def loadSnapshots(): ConcurrentSkipListMap[java.lang.Long, SnapshotFile] = {
+    val tm = new ConcurrentSkipListMap[java.lang.Long, SnapshotFile]()
+    for (f <- ProducerStateManager.listSnapshotFiles(_logDir)) {
+      tm.put(f.offset, f)
+    }
+    tm
+  }
+
+  /**
+   * Scans the log directory, gathering all producer state snapshot files. Snapshot files which do not have an offset
+   * corresponding to one of the provided offsets in segmentBaseOffsets will be removed, except in the case that there
+   * is a snapshot file at a higher offset than any offset in segmentBaseOffsets.
+   *
+   * The goal here is to remove any snapshot files which do not have an associated segment file, but not to remove the
+   * largest stray snapshot file which was emitted during clean shutdown.
+   */
+  private[log] def removeStraySnapshots(segmentBaseOffsets: Set[Long]): Unit = {
+    var latestStraySnapshot: Option[SnapshotFile] = None
+    val ss = loadSnapshots()
+    for (snapshot <- ss.values().asScala) {
+      val key = snapshot.offset
+      latestStraySnapshot match {
+        case Some(prev) =>
+          if (!segmentBaseOffsets.contains(key)) {
+            // this snapshot is now the largest stray snapshot.
+            prev.deleteIfExists()
+            ss.remove(prev.offset)
+            latestStraySnapshot = Some(snapshot)
+          }
+        case None =>
+          if (!segmentBaseOffsets.contains(key)) {
+            latestStraySnapshot = Some(snapshot)

Review comment:
       We perform a check below which may cover this case. After setting the `snapshots` map, we look at the latest snapshot in the map. If the latest snapshot in the map is not equal to the `latestStraySnapshot`, we delete the `latestStraySnapshot`.
   
   I think this is a bit confusing though, so it might be better if instead we directly check that the `latestStraySnapshot` is larger than the largest offset in `segmentBaseOffsets`.
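   
   A minimal sketch of that direct check, using the names from this diff (illustrative, not the final code):

```scala
// After the scan, keep latestStraySnapshot only if its offset is beyond every
// segment base offset; otherwise delete it like any other stray.
val maxSegmentBaseOffset: Option[Long] =
  if (segmentBaseOffsets.isEmpty) None else Some(segmentBaseOffsets.max)

latestStraySnapshot.foreach { snapshot =>
  if (!maxSegmentBaseOffset.forall(snapshot.offset > _)) {
    snapshot.deleteIfExists()
    ss.remove(snapshot.offset)
  }
}
```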

##########
File path: core/src/test/scala/unit/kafka/log/LogTest.scala
##########
@@ -1226,6 +1225,104 @@ class LogTest {
     assertEquals(retainedLastSeqOpt, reloadedLastSeqOpt)
   }
 
+  @Test
+  def testRetentionDeletesProducerStateSnapshots(): Unit = {
+    val logConfig = LogTest.createLogConfig(segmentBytes = 2048 * 5, retentionBytes = 0, retentionMs = 1000 * 60, fileDeleteDelayMs = 0)
+    val log = createLog(logDir, logConfig)
+    val pid1 = 1L
+    val epoch = 0.toShort
+
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("a".getBytes)), producerId = pid1,
+      producerEpoch = epoch, sequence = 0), leaderEpoch = 0)
+    log.roll()
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("b".getBytes)), producerId = pid1,
+      producerEpoch = epoch, sequence = 1), leaderEpoch = 0)
+    log.roll()
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("c".getBytes)), producerId = pid1,
+      producerEpoch = epoch, sequence = 2), leaderEpoch = 0)
+
+    log.updateHighWatermark(log.logEndOffset)
+
+    assertEquals(2, ProducerStateManager.listSnapshotFiles(logDir).size)
+    // Sleep to breach the retention period
+    mockTime.sleep(1000 * 60 + 1)
+    log.deleteOldSegments()
+    // Sleep to breach the file delete delay and run scheduled file deletion tasks
+    mockTime.sleep(1)
+    assertEquals("expect a single producer state snapshot remaining", 1, ProducerStateManager.listSnapshotFiles(logDir).size)
+  }
+
+  @Test
+  def testLogStartOffsetMovementDeletesSnapshots(): Unit = {
+    val logConfig = LogTest.createLogConfig(segmentBytes = 2048 * 5, retentionBytes = -1, fileDeleteDelayMs = 0)
+    val log = createLog(logDir, logConfig)
+    val pid1 = 1L
+    val epoch = 0.toShort
+
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("a".getBytes)), producerId = pid1,
+      producerEpoch = epoch, sequence = 0), leaderEpoch = 0)
+    log.roll()
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("b".getBytes)), producerId = pid1,
+      producerEpoch = epoch, sequence = 1), leaderEpoch = 0)
+    log.roll()
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("c".getBytes)), producerId = pid1,
+      producerEpoch = epoch, sequence = 2), leaderEpoch = 0)
+    log.updateHighWatermark(log.logEndOffset)
+    assertEquals(2, ProducerStateManager.listSnapshotFiles(logDir).size)
+
+    // Increment the log start offset to exclude the first two segments.
+    log.maybeIncrementLogStartOffset(log.logEndOffset - 1, ClientRecordDeletion)
+    log.deleteOldSegments()
+    // Sleep to breach the file delete delay and run scheduled file deletion tasks
+    mockTime.sleep(1)
+    assertEquals("expect a single producer state snapshot remaining", 1, ProducerStateManager.listSnapshotFiles(logDir).size)
+  }
+
+  @Test
+  def testCompactionDeletesProducerStateSnapshots(): Unit = {
+    val logConfig = LogTest.createLogConfig(segmentBytes = 2048 * 5, cleanupPolicy = LogConfig.Compact, fileDeleteDelayMs = 0)
+    val log = createLog(logDir, logConfig)
+    val pid1 = 1L
+    val epoch = 0.toShort
+    val cleaner = new Cleaner(id = 0,
+      offsetMap = new FakeOffsetMap(Int.MaxValue),
+      ioBufferSize = 64 * 1024,
+      maxIoBufferSize = 64 * 1024,
+      dupBufferLoadFactor = 0.75,
+      throttler = new Throttler(Double.MaxValue, Long.MaxValue, false, time = mockTime),
+      time = mockTime,
+      checkDone = _ => {})
+
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("a".getBytes, "a".getBytes())), producerId = pid1,
+      producerEpoch = epoch, sequence = 0), leaderEpoch = 0)
+    log.roll()
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("a".getBytes, "b".getBytes())), producerId = pid1,
+      producerEpoch = epoch, sequence = 1), leaderEpoch = 0)
+    log.roll()
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("a".getBytes, "c".getBytes())), producerId = pid1,
+      producerEpoch = epoch, sequence = 2), leaderEpoch = 0)
+    log.updateHighWatermark(log.logEndOffset)
+    assertEquals("expected a snapshot file per segment base offset, except the first segment", log.logSegments.map(_.baseOffset).toSeq.sorted.drop(1), ProducerStateManager.listSnapshotFiles(logDir).map(_.offset).sorted)
+    assertEquals(2, ProducerStateManager.listSnapshotFiles(logDir).size)
+
+    // Clean the segments; this should delete everything except the active segment,
+    // since the only key that exists is "a".
+    cleaner.clean(LogToClean(log.topicPartition, log, 0, log.logEndOffset))
+    log.deleteOldSegments()
+    // Sleep to breach the file delete delay and run scheduled file deletion tasks
+    mockTime.sleep(1)
+    assertEquals("expected a snapshot file per segment base offset, excluding the first", log.logSegments.map(_.baseOffset).toSeq.sorted.drop(1), ProducerStateManager.listSnapshotFiles(logDir).map(_.offset).sorted)
+  }
+
+  @Test
+  def testLoadingLogCleansOrphanedProducerStateSnapshots(): Unit = {
+    val orphanedSnapshotFile = Log.producerSnapshotFile(logDir, 42).toPath
+    Files.createFile(orphanedSnapshotFile)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 2048 * 5, retentionBytes = -1, fileDeleteDelayMs = 0)
+    createLog(logDir, logConfig)
+    assertEquals("expected orphaned producer state snapshot file to be cleaned up", 0, ProducerStateManager.listSnapshotFiles(logDir).size)

Review comment:
       It's being deleted because, during producer state loading, we truncate producer state to match the bounds of the log, and the snapshot file written out at offset 42 is higher than the log end offset of the empty log. The test name is not very clear in this case though.
   
   I will fix the name and add another test which checks that we keep around the largest stray producer state snapshot file.
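   
   Roughly, the loading-time truncation described above behaves like this (a sketch under the assumption that snapshots outside the log's bounds are simply dropped; not the PR's exact code):

```scala
// Snapshots at offsets beyond the log end offset (or at/below the log start
// offset) cannot describe valid producer state, so delete them on load.
// For an empty log, logEndOffset is 0, so a snapshot at offset 42 is removed.
def removeOutOfRangeSnapshots(logStartOffset: Long, logEndOffset: Long): Unit = {
  val outOfRange = snapshots.values().asScala
    .filter(s => s.offset > logEndOffset || s.offset <= logStartOffset)
    .toList
  outOfRange.foreach { snapshot =>
    snapshots.remove(snapshot.offset)
    snapshot.deleteIfExists()
  }
}
```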

##########
File path: core/src/test/scala/unit/kafka/log/ProducerStateManagerTest.scala
##########
@@ -834,6 +834,40 @@ class ProducerStateManagerTest {
     assertEquals(None, stateManager.lastEntry(producerId).get.currentTxnFirstOffset)
   }
 
+  @Test
+  def testRemoveStraySnapshotsKeepCleanShutdownSnapshot(): Unit = {
+    // Test that when stray snapshots are removed, the largest stray snapshot is kept around. This covers the case
+    // where the broker shut down cleanly and emitted a snapshot file larger than the base offset of the active segment.
+
+    // Create 3 snapshot files at different offsets.
+    Log.producerSnapshotFile(logDir, 42).createNewFile()
+    Log.producerSnapshotFile(logDir, 5).createNewFile()
+    Log.producerSnapshotFile(logDir, 2).createNewFile()
+
+    // claim that we only have one segment with a base offset of 5
+    stateManager.removeStraySnapshots(Set(5))
+
+    // The snapshot file at offset 2 should be considered a stray, but the snapshot at 42 should be kept
+    // around because it is the largest snapshot.
+    assertEquals(Some(42), stateManager.latestSnapshotOffset)
+    assertEquals(Some(5), stateManager.oldestSnapshotOffset)
+    assertEquals(Seq(5, 42), ProducerStateManager.listSnapshotFiles(logDir).map(_.offset).sorted)
+  }
+
+  @Test
+  def testRemoveAllStraySnapshots(): Unit = {
+    // Test that when stray snapshots are removed, all stray snapshots are removed when the base offset of the
+    // largest segment exceeds the offset of the largest stray snapshot.

Review comment:
       I think this sentence is a bit confusing. Snapshot 42 is not meant to be a stray snapshot here, only 5 and 2 are. I will try to reword this.

##########
File path: core/src/test/scala/unit/kafka/log/ProducerStateManagerTest.scala
##########
@@ -834,6 +834,40 @@ class ProducerStateManagerTest {
     assertEquals(None, stateManager.lastEntry(producerId).get.currentTxnFirstOffset)
   }
 
+  @Test
+  def testRemoveStraySnapshotsKeepCleanShutdownSnapshot(): Unit = {
+    // Test that when stray snapshots are removed, the largest stray snapshot is kept around. This covers the case
+    // where the broker shut down cleanly and emitted a snapshot file larger than the base offset of the active segment.
+
+    // Create 3 snapshot files at different offsets.
+    Log.producerSnapshotFile(logDir, 42).createNewFile()
+    Log.producerSnapshotFile(logDir, 5).createNewFile()
+    Log.producerSnapshotFile(logDir, 2).createNewFile()
+
+    // claim that we only have one segment with a base offset of 5
+    stateManager.removeStraySnapshots(Set(5))
+
+    // The snapshot file at offset 2 should be considered a stray, but the snapshot at 42 should be kept
+    // around because it is the largest snapshot.
+    assertEquals(Some(42), stateManager.latestSnapshotOffset)
+    assertEquals(Some(5), stateManager.oldestSnapshotOffset)
+    assertEquals(Seq(5, 42), ProducerStateManager.listSnapshotFiles(logDir).map(_.offset).sorted)
+  }
+
+  @Test
+  def testRemoveAllStraySnapshots(): Unit = {
+    // Test that when stray snapshots are removed, all stray snapshots are removed when the base offset of the
+    // largest segment exceeds the offset of the largest stray snapshot.

Review comment:
       Hmm, I think my comment here could be worded better. Offset `42` here is not a "stray", since we provide it along with the list of segmentBaseOffsets to `removeStraySnapshots`.
   
   I'll change up the wording on this, thanks!

##########
File path: core/src/test/scala/unit/kafka/log/LogTest.scala
##########
@@ -1226,6 +1225,104 @@ class LogTest {
     assertEquals(retainedLastSeqOpt, reloadedLastSeqOpt)
   }
 
+  @Test
+  def testRetentionDeletesProducerStateSnapshots(): Unit = {
+    val logConfig = LogTest.createLogConfig(segmentBytes = 2048 * 5, retentionBytes = 0, retentionMs = 1000 * 60, fileDeleteDelayMs = 0)
+    val log = createLog(logDir, logConfig)
+    val pid1 = 1L
+    val epoch = 0.toShort
+
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("a".getBytes)), producerId = pid1,
+      producerEpoch = epoch, sequence = 0), leaderEpoch = 0)
+    log.roll()
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("b".getBytes)), producerId = pid1,
+      producerEpoch = epoch, sequence = 1), leaderEpoch = 0)
+    log.roll()
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("c".getBytes)), producerId = pid1,
+      producerEpoch = epoch, sequence = 2), leaderEpoch = 0)
+
+    log.updateHighWatermark(log.logEndOffset)
+
+    assertEquals(2, ProducerStateManager.listSnapshotFiles(logDir).size)
+    // Sleep to breach the retention period
+    mockTime.sleep(1000 * 60 + 1)
+    log.deleteOldSegments()
+    // Sleep to breach the file delete delay and run scheduled file deletion tasks
+    mockTime.sleep(1)
+    assertEquals("expect a single producer state snapshot remaining", 1, ProducerStateManager.listSnapshotFiles(logDir).size)
+  }
+
+  @Test
+  def testLogStartOffsetMovementDeletesSnapshots(): Unit = {
+    val logConfig = LogTest.createLogConfig(segmentBytes = 2048 * 5, retentionBytes = -1, fileDeleteDelayMs = 0)
+    val log = createLog(logDir, logConfig)
+    val pid1 = 1L
+    val epoch = 0.toShort
+
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("a".getBytes)), producerId = pid1,
+      producerEpoch = epoch, sequence = 0), leaderEpoch = 0)
+    log.roll()
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("b".getBytes)), producerId = pid1,
+      producerEpoch = epoch, sequence = 1), leaderEpoch = 0)
+    log.roll()
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("c".getBytes)), producerId = pid1,
+      producerEpoch = epoch, sequence = 2), leaderEpoch = 0)
+    log.updateHighWatermark(log.logEndOffset)
+    assertEquals(2, ProducerStateManager.listSnapshotFiles(logDir).size)
+
+    // Increment the log start offset to exclude the first two segments.
+    log.maybeIncrementLogStartOffset(log.logEndOffset - 1, ClientRecordDeletion)
+    log.deleteOldSegments()
+    // Sleep to breach the file delete delay and run scheduled file deletion tasks
+    mockTime.sleep(1)
+    assertEquals("expect a single producer state snapshot remaining", 1, ProducerStateManager.listSnapshotFiles(logDir).size)
+  }
+
+  @Test
+  def testCompactionDeletesProducerStateSnapshots(): Unit = {
+    val logConfig = LogTest.createLogConfig(segmentBytes = 2048 * 5, cleanupPolicy = LogConfig.Compact, fileDeleteDelayMs = 0)
+    val log = createLog(logDir, logConfig)
+    val pid1 = 1L
+    val epoch = 0.toShort
+    val cleaner = new Cleaner(id = 0,
+      offsetMap = new FakeOffsetMap(Int.MaxValue),
+      ioBufferSize = 64 * 1024,
+      maxIoBufferSize = 64 * 1024,
+      dupBufferLoadFactor = 0.75,
+      throttler = new Throttler(Double.MaxValue, Long.MaxValue, false, time = mockTime),
+      time = mockTime,
+      checkDone = _ => {})
+
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("a".getBytes, "a".getBytes())), producerId = pid1,
+      producerEpoch = epoch, sequence = 0), leaderEpoch = 0)
+    log.roll()
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("a".getBytes, "b".getBytes())), producerId = pid1,
+      producerEpoch = epoch, sequence = 1), leaderEpoch = 0)
+    log.roll()
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("a".getBytes, "c".getBytes())), producerId = pid1,
+      producerEpoch = epoch, sequence = 2), leaderEpoch = 0)
+    log.updateHighWatermark(log.logEndOffset)
+    assertEquals("expected a snapshot file per segment base offset, except the first segment", log.logSegments.map(_.baseOffset).toSeq.sorted.drop(1), ProducerStateManager.listSnapshotFiles(logDir).map(_.offset).sorted)
+    assertEquals(2, ProducerStateManager.listSnapshotFiles(logDir).size)
+
+    // Clean the segments; this should delete everything except the active segment,
+    // since the only key that exists is "a".
+    cleaner.clean(LogToClean(log.topicPartition, log, 0, log.logEndOffset))
+    log.deleteOldSegments()
+    // Sleep to breach the file delete delay and run scheduled file deletion tasks
+    mockTime.sleep(1)
+    assertEquals("expected a snapshot file per segment base offset, excluding the first", log.logSegments.map(_.baseOffset).toSeq.sorted.drop(1), ProducerStateManager.listSnapshotFiles(logDir).map(_.offset).sorted)
+  }
+
+  @Test
+  def testLoadingLogCleansOrphanedProducerStateSnapshots(): Unit = {
+    val orphanedSnapshotFile = Log.producerSnapshotFile(logDir, 42).toPath
+    Files.createFile(orphanedSnapshotFile)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 2048 * 5, retentionBytes = -1, fileDeleteDelayMs = 0)
+    createLog(logDir, logConfig)
+    assertEquals("expected orphaned producer state snapshot file to be cleaned up", 0, ProducerStateManager.listSnapshotFiles(logDir).size)

Review comment:
       I added some extra context to the existing test and created a new test, `testLoadingLogKeepsLargestStrayProducerStateSnapshot`, which verifies from the log that the largest stray snapshot within the log's end offset is retained.
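   
   For reference, a minimal sketch of what that new test could look like (only the test name comes from this thread; the body is illustrative and the real test may differ):

```scala
@Test
def testLoadingLogKeepsLargestStrayProducerStateSnapshot(): Unit = {
  val logConfig = LogTest.createLogConfig(segmentBytes = 2048 * 5, retentionBytes = -1, fileDeleteDelayMs = 0)
  var log = createLog(logDir, logConfig)
  val pid1 = 1L
  val epoch = 0.toShort

  log.appendAsLeader(TestUtils.records(List(new SimpleRecord("a".getBytes)), producerId = pid1,
    producerEpoch = epoch, sequence = 0), leaderEpoch = 0)
  log.appendAsLeader(TestUtils.records(List(new SimpleRecord("b".getBytes)), producerId = pid1,
    producerEpoch = epoch, sequence = 1), leaderEpoch = 0)
  log.updateHighWatermark(log.logEndOffset)
  log.close()

  // A stray snapshot within the log's bounds: it matches no segment base
  // offset, but as the largest stray it should survive the reload.
  Files.createFile(Log.producerSnapshotFile(logDir, 2).toPath)

  log = createLog(logDir, logConfig)
  assertTrue(ProducerStateManager.listSnapshotFiles(logDir).map(_.offset).contains(2L))
}
```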

##########
File path: core/src/test/scala/unit/kafka/log/LogTest.scala
##########
@@ -782,7 +782,7 @@ class LogTest {
     }
 
     // Retain snapshots for the last 2 segments
-    ProducerStateManager.deleteSnapshotsBefore(logDir, segmentOffsets(segmentOffsets.size - 2))
+    ProducerStateManager.listSnapshotFiles(logDir).filter(_.offset < segmentOffsets(segmentOffsets.size - 2)).foreach(_.deleteIfExists())

Review comment:
       Yes, it should work if we switch these back to using `deleteSnapshotsBefore`. Thanks!
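   
   A restored `deleteSnapshotsBefore` would presumably just wrap the logic from this diff (a sketch, assuming the `SnapshotFile`-based listing introduced in this PR):

```scala
// Delete every producer state snapshot whose offset is strictly below the cutoff.
private[log] def deleteSnapshotsBefore(logDir: File, offset: Long): Unit =
  ProducerStateManager.listSnapshotFiles(logDir)
    .filter(_.offset < offset)
    .foreach(_.deleteIfExists())
```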



