gardnervickers commented on a change in pull request #7929:
URL: https://github.com/apache/kafka/pull/7929#discussion_r501942512
##########
File path: core/src/main/scala/kafka/log/ProducerStateManager.scala
##########
@@ -653,36 +697,44 @@ class ProducerStateManager(val topicPartition: TopicPartition,
   def takeSnapshot(): Unit = {
     // If not a new offset, then it is not worth taking another snapshot
     if (lastMapOffset > lastSnapOffset) {
-      val snapshotFile = Log.producerSnapshotFile(logDir, lastMapOffset)
+      val snapshotFile = SnapshotFile(Log.producerSnapshotFile(_logDir, lastMapOffset))
       info(s"Writing producer snapshot at offset $lastMapOffset")
-      writeSnapshot(snapshotFile, producers)
+      writeSnapshot(snapshotFile.file, producers)
+      snapshots.put(snapshotFile.offset, snapshotFile)

       // Update the last snap offset according to the serialized map
       lastSnapOffset = lastMapOffset
     }
   }

+  /**
+   * Update the parentDir for this ProducerStateManager and all of the snapshot files which it manages.
+   */
+  def updateParentDir(parentDir: File): Unit = {
+    _logDir = parentDir
+    snapshots.forEach((_, s) => s.updateParentDir(parentDir))
+  }
+
   /**
    * Get the last offset (exclusive) of the latest snapshot file.
    */
-  def latestSnapshotOffset: Option[Long] = latestSnapshotFile.map(file => offsetFromFile(file))
+  def latestSnapshotOffset: Option[Long] = latestSnapshotFile.map(_.offset)

   /**
    * Get the last offset (exclusive) of the oldest snapshot file.
    */
-  def oldestSnapshotOffset: Option[Long] = oldestSnapshotFile.map(file => offsetFromFile(file))
+  def oldestSnapshotOffset: Option[Long] = oldestSnapshotFile.map(_.offset)

   /**
-   * When we remove the head of the log due to retention, we need to remove snapshots older than
-   * the new log start offset.
+   * Remove any unreplicated transactions lower than the provided logStartOffset and bring the lastMapOffset forward
+   * if necessary.
    */
-  def truncateHead(logStartOffset: Long): Unit = {
+  def onLogStartOffsetIncremented(logStartOffset: Long): Unit = {
     removeUnreplicatedTransactions(logStartOffset)
     if (lastMapOffset < logStartOffset)
       lastMapOffset = logStartOffset
-    deleteSnapshotsBefore(logStartOffset)

Review comment:
   The idea here is to clear unreplicated transactions and, when necessary, advance `lastMapOffset` and `lastSnapOffset` as the logStartOffset is advanced, but to leave the snapshot files around. The corresponding snapshot files should instead be removed during the retention pass, as we clean up the associated segment files. I was attempting to optimize incrementing the logStartOffset a bit so that we don't need to delete snapshot files from the request handler thread when handling `DELETE_RECORDS`.
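   To make the intended division of labor concrete, here is a minimal sketch of how the retention pass could take over snapshot deletion. It is illustrative only: `deleteSegmentFiles` is a simplified stand-in for `Log`'s deletion path, and `removeAndDeleteSnapshot` is an assumed helper name, not necessarily the PR's actual API.

```scala
// Hypothetical sketch: the retention pass deletes each segment's corresponding
// producer state snapshot along with the segment files, so advancing the
// logStartOffset (e.g. for DELETE_RECORDS) never performs file deletion itself.
private def deleteSegmentFiles(segments: Iterable[LogSegment]): Unit = {
  segments.foreach { segment =>
    segment.deleteIfExists()
    // Assumed helper: removes the snapshot entry at this base offset from the
    // snapshots map and deletes the backing file, if present.
    producerStateManager.removeAndDeleteSnapshot(segment.baseOffset)
  }
}
```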
##########
File path: core/src/main/scala/kafka/log/ProducerStateManager.scala
##########
@@ -496,6 +491,53 @@ class ProducerStateManager(val topicPartition: TopicPartition,
   // completed transactions whose markers are at offsets above the high watermark
   private val unreplicatedTxns = new util.TreeMap[Long, TxnMetadata]

+  /**
+   * Load producer state snapshots by scanning the _logDir.
+   */
+  private def loadSnapshots(): ConcurrentSkipListMap[java.lang.Long, SnapshotFile] = {
+    val tm = new ConcurrentSkipListMap[java.lang.Long, SnapshotFile]()
+    for (f <- ProducerStateManager.listSnapshotFiles(_logDir)) {
+      tm.put(f.offset, f)
+    }
+    tm
+  }
+
+  /**
+   * Scans the log directory, gathering all producer state snapshot files. Snapshot files which do not have an offset
+   * corresponding to one of the provided offsets in segmentBaseOffsets will be removed, except in the case that there
+   * is a snapshot file at a higher offset than any offset in segmentBaseOffsets.
+   *
+   * The goal here is to remove any snapshot files which do not have an associated segment file, but not to remove the
+   * largest stray snapshot file, which may have been written by a clean shutdown.
+   */
+  private[log] def removeStraySnapshots(segmentBaseOffsets: Set[Long]): Unit = {
+    var latestStraySnapshot: Option[SnapshotFile] = None
+    val ss = loadSnapshots()
+    for (snapshot <- ss.values().asScala) {
+      val key = snapshot.offset
+      latestStraySnapshot match {
+        case Some(prev) =>
+          if (!segmentBaseOffsets.contains(key)) {
+            // this snapshot is now the largest stray snapshot.
+            prev.deleteIfExists()
+            ss.remove(prev.offset)
+            latestStraySnapshot = Some(snapshot)
+          }
+        case None =>
+          if (!segmentBaseOffsets.contains(key)) {
+            latestStraySnapshot = Some(snapshot)

Review comment:
   We perform a check below which may cover this case. After setting the `snapshots` map, we look at the latest snapshot in the map; if it is not equal to the `latestStraySnapshot`, we delete the `latestStraySnapshot`. I think this is a bit confusing, though, so it might be better to instead directly check that the `latestStraySnapshot` is larger than the largest offset in `segmentBaseOffsets`.
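   For illustration, a minimal sketch of that direct check, reusing the names from the diff above (`latestStraySnapshot`, `ss`, `segmentBaseOffsets`); this is the suggested alternative, not code from the PR:

```scala
// After the scan: keep the latest stray snapshot only if it sits above every
// segment base offset (the clean-shutdown case); otherwise it is covered by an
// existing segment range and is safe to delete.
val maxSegmentBaseOffset = if (segmentBaseOffsets.isEmpty) Long.MinValue else segmentBaseOffsets.max
latestStraySnapshot.foreach { snapshot =>
  if (snapshot.offset < maxSegmentBaseOffset) {
    snapshot.deleteIfExists()
    ss.remove(snapshot.offset)
  }
}
```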
##########
File path: core/src/test/scala/unit/kafka/log/LogTest.scala
##########
@@ -1226,6 +1225,104 @@ class LogTest {
     assertEquals(retainedLastSeqOpt, reloadedLastSeqOpt)
   }

+  @Test
+  def testRetentionDeletesProducerStateSnapshots(): Unit = {
+    val logConfig = LogTest.createLogConfig(segmentBytes = 2048 * 5, retentionBytes = 0, retentionMs = 1000 * 60, fileDeleteDelayMs = 0)
+    val log = createLog(logDir, logConfig)
+    val pid1 = 1L
+    val epoch = 0.toShort
+
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("a".getBytes)), producerId = pid1,
+      producerEpoch = epoch, sequence = 0), leaderEpoch = 0)
+    log.roll()
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("b".getBytes)), producerId = pid1,
+      producerEpoch = epoch, sequence = 1), leaderEpoch = 0)
+    log.roll()
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("c".getBytes)), producerId = pid1,
+      producerEpoch = epoch, sequence = 2), leaderEpoch = 0)
+
+    log.updateHighWatermark(log.logEndOffset)
+
+    assertEquals(2, ProducerStateManager.listSnapshotFiles(logDir).size)
+    // Sleep to breach the retention period
+    mockTime.sleep(1000 * 60 + 1)
+    log.deleteOldSegments()
+    // Sleep to breach the file delete delay and run scheduled file deletion tasks
+    mockTime.sleep(1)
+    assertEquals("expect a single producer state snapshot remaining", 1, ProducerStateManager.listSnapshotFiles(logDir).size)
+  }
+
+  @Test
+  def testLogStartOffsetMovementDeletesSnapshots(): Unit = {
+    val logConfig = LogTest.createLogConfig(segmentBytes = 2048 * 5, retentionBytes = -1, fileDeleteDelayMs = 0)
+    val log = createLog(logDir, logConfig)
+    val pid1 = 1L
+    val epoch = 0.toShort
+
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("a".getBytes)), producerId = pid1,
+      producerEpoch = epoch, sequence = 0), leaderEpoch = 0)
+    log.roll()
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("b".getBytes)), producerId = pid1,
+      producerEpoch = epoch, sequence = 1), leaderEpoch = 0)
+    log.roll()
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("c".getBytes)), producerId = pid1,
+      producerEpoch = epoch, sequence = 2), leaderEpoch = 0)
+    log.updateHighWatermark(log.logEndOffset)
+    assertEquals(2, ProducerStateManager.listSnapshotFiles(logDir).size)
+
+    // Increment the log start offset to exclude the first two segments.
+    log.maybeIncrementLogStartOffset(log.logEndOffset - 1, ClientRecordDeletion)
+    log.deleteOldSegments()
+    // Sleep to breach the file delete delay and run scheduled file deletion tasks
+    mockTime.sleep(1)
+    assertEquals("expect a single producer state snapshot remaining", 1, ProducerStateManager.listSnapshotFiles(logDir).size)
+  }
+
+  @Test
+  def testCompactionDeletesProducerStateSnapshots(): Unit = {
+    val logConfig = LogTest.createLogConfig(segmentBytes = 2048 * 5, cleanupPolicy = LogConfig.Compact, fileDeleteDelayMs = 0)
+    val log = createLog(logDir, logConfig)
+    val pid1 = 1L
+    val epoch = 0.toShort
+    val cleaner = new Cleaner(id = 0,
+      offsetMap = new FakeOffsetMap(Int.MaxValue),
+      ioBufferSize = 64 * 1024,
+      maxIoBufferSize = 64 * 1024,
+      dupBufferLoadFactor = 0.75,
+      throttler = new Throttler(Double.MaxValue, Long.MaxValue, false, time = mockTime),
+      time = mockTime,
+      checkDone = _ => {})
+
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("a".getBytes, "a".getBytes())), producerId = pid1,
+      producerEpoch = epoch, sequence = 0), leaderEpoch = 0)
+    log.roll()
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("a".getBytes, "b".getBytes())), producerId = pid1,
+      producerEpoch = epoch, sequence = 1), leaderEpoch = 0)
+    log.roll()
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("a".getBytes, "c".getBytes())), producerId = pid1,
+      producerEpoch = epoch, sequence = 2), leaderEpoch = 0)
+    log.updateHighWatermark(log.logEndOffset)
+    assertEquals("expected a snapshot file per segment base offset, except the first segment", log.logSegments.map(_.baseOffset).toSeq.sorted.drop(1), ProducerStateManager.listSnapshotFiles(logDir).map(_.offset).sorted)
+    assertEquals(2, ProducerStateManager.listSnapshotFiles(logDir).size)
+
+    // Clean segments, this should delete everything except the active segment since there only
+    // exists the key "a".
+    cleaner.clean(LogToClean(log.topicPartition, log, 0, log.logEndOffset))
+    log.deleteOldSegments()
+    // Sleep to breach the file delete delay and run scheduled file deletion tasks
+    mockTime.sleep(1)
+    assertEquals("expected a snapshot file per segment base offset, excluding the first", log.logSegments.map(_.baseOffset).toSeq.sorted.drop(1), ProducerStateManager.listSnapshotFiles(logDir).map(_.offset).sorted)
+  }
+
+  @Test
+  def testLoadingLogCleansOrphanedProducerStateSnapshots(): Unit = {
+    val orphanedSnapshotFile = Log.producerSnapshotFile(logDir, 42).toPath
+    Files.createFile(orphanedSnapshotFile)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 2048 * 5, retentionBytes = -1, fileDeleteDelayMs = 0)
+    createLog(logDir, logConfig)
+    assertEquals("expected orphaned producer state snapshot file to be cleaned up", 0, ProducerStateManager.listSnapshotFiles(logDir).size)

Review comment:
   It's being deleted during producer state loading because we truncate producer state to match the bounds of the log, and the snapshot file written out at offset 42 is higher than the log end offset of the empty log. The test name is not very clear in this case, though. I will fix the name and add another test which checks that we keep the largest stray producer state snapshot file around.
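   For context, a rough sketch of the mechanism described above, using `ProducerStateManager.truncateAndReload` as the entry point (arguments illustrative):

```scala
// During load, producer state is truncated to match the log's bounds
// [logStartOffset, logEndOffset]. For an empty log both bounds are 0, so a
// snapshot at offset 42 lies outside the range and its file is deleted.
stateManager.truncateAndReload(0L, 0L, mockTime.milliseconds())
assertEquals(None, stateManager.latestSnapshotOffset)
```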
##########
File path: core/src/test/scala/unit/kafka/log/ProducerStateManagerTest.scala
##########
@@ -834,6 +834,40 @@ class ProducerStateManagerTest {
     assertEquals(None, stateManager.lastEntry(producerId).get.currentTxnFirstOffset)
   }

+  @Test
+  def testRemoveStraySnapshotsKeepCleanShutdownSnapshot(): Unit = {
+    // Test that when stray snapshots are removed, the largest stray snapshot is kept around. This covers the case
+    // where the broker shut down cleanly and emitted a snapshot file larger than the base offset of the active segment.
+
+    // Create 3 snapshot files at different offsets.
+    Log.producerSnapshotFile(logDir, 42).createNewFile()
+    Log.producerSnapshotFile(logDir, 5).createNewFile()
+    Log.producerSnapshotFile(logDir, 2).createNewFile()
+
+    // Claim that we only have one segment, with a base offset of 5.
+    stateManager.removeStraySnapshots(Set(5))
+
+    // The snapshot file at offset 2 should be considered a stray, but the snapshot at 42 should be kept
+    // around because it is the largest snapshot.
+    assertEquals(Some(42), stateManager.latestSnapshotOffset)
+    assertEquals(Some(5), stateManager.oldestSnapshotOffset)
+    assertEquals(Seq(5, 42), ProducerStateManager.listSnapshotFiles(logDir).map(_.offset).sorted)
+  }
+
+  @Test
+  def testRemoveAllStraySnapshots(): Unit = {
+    // Test that when stray snapshots are removed, all stray snapshots are removed when the base offset of the largest
+    // segment exceeds the offset of the largest stray snapshot.

Review comment:
   I think this sentence is a bit confusing. Snapshot 42 is not meant to be a stray snapshot here; only 5 and 2 are. I will try to reword this.

Review comment:
   Hmm, I think my comment here could be worded better. Offset `42` here is not a "stray", since we provide it along with the list of segmentBaseOffsets to `removeStraySnapshots`. I'll change up the wording on this, thanks!
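   To make the distinction concrete, a sketch of the semantics under discussion (offsets mirror the test setup above; this is illustrative, not the test's exact body):

```scala
// Snapshots at 2 and 5 are strays (no segment starts at those offsets); 42
// matches a provided segment base offset, so it is not a stray and must survive.
Log.producerSnapshotFile(logDir, 2).createNewFile()
Log.producerSnapshotFile(logDir, 5).createNewFile()
Log.producerSnapshotFile(logDir, 42).createNewFile()
stateManager.removeStraySnapshots(Set(42))
// Both strays sit below the largest segment base offset, so both are removed.
assertEquals(Seq(42), ProducerStateManager.listSnapshotFiles(logDir).map(_.offset).sorted)
```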
##########
File path: core/src/test/scala/unit/kafka/log/LogTest.scala
##########
(same hunk as above, ending at `testLoadingLogCleansOrphanedProducerStateSnapshots`)

Review comment:
   I added some extra context to the existing test and created a new test, `testLoadingLogKeepsLargestStrayProducerStateSnapshot`, which verifies from the log that the largest stray snapshot within the log's end offset is retained.

##########
File path: core/src/test/scala/unit/kafka/log/LogTest.scala
##########
@@ -782,7 +782,7 @@ class LogTest {
     }

     // Retain snapshots for the last 2 segments
-    ProducerStateManager.deleteSnapshotsBefore(logDir, segmentOffsets(segmentOffsets.size - 2))
+    ProducerStateManager.listSnapshotFiles(logDir).filter(_.offset < segmentOffsets(segmentOffsets.size - 2)).foreach(_.deleteIfExists())

Review comment:
   Yes, it should work if we switch these back to using deleteSnapshotsBefore. Thanks!
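   For reference, a minimal sketch of what the restored helper could look like; the placement and exact signature are assumed, and the body simply names the inline filter from the diff:

```scala
// Delete all producer state snapshot files with offsets strictly below the
// given offset, equivalent to the listSnapshotFiles/filter/foreach chain above.
private[log] def deleteSnapshotsBefore(dir: File, offset: Long): Unit =
  ProducerStateManager.listSnapshotFiles(dir)
    .filter(_.offset < offset)
    .foreach(_.deleteIfExists())
```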
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org