This is an automated email from the ASF dual-hosted git repository.

davidarthur pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 924c870  KAFKA-12543: Change RawSnapshotReader ownership model (#10431)
924c870 is described below

commit 924c870fb102041363acc00227f8dba152dc830f
Author: José Armando García Sancio <[email protected]>
AuthorDate: Tue May 18 11:14:17 2021 -0700

    KAFKA-12543: Change RawSnapshotReader ownership model (#10431)
    
    The Kafka networking layer doesn't close FileRecords and assumes that they are 
already open when sending them over a channel. To support this pattern, this 
commit changes the ownership model for FileRawSnapshotReader so that readers are 
owned by KafkaMetadataLog.
    
    Reviewers: dengziming <[email protected]>, David Arthur <[email protected]>, 
Jun Rao <[email protected]>
---
 .../main/scala/kafka/raft/KafkaMetadataLog.scala   | 190 +++++++++++-------
 .../scala/kafka/raft/KafkaMetadataLogTest.scala    |   4 +-
 .../org/apache/kafka/raft/KafkaRaftClient.java     |  81 ++++----
 .../kafka/snapshot/FileRawSnapshotReader.java      |  18 +-
 .../apache/kafka/snapshot/RawSnapshotReader.java   |   4 +-
 .../java/org/apache/kafka/snapshot/Snapshots.java  |  44 ++--
 .../kafka/raft/KafkaRaftClientSnapshotTest.java    | 223 ++++++++++-----------
 .../test/java/org/apache/kafka/raft/MockLog.java   |   3 -
 .../java/org/apache/kafka/raft/MockLogTest.java    |   5 +-
 .../org/apache/kafka/snapshot/SnapshotsTest.java   |   2 +-
 10 files changed, 320 insertions(+), 254 deletions(-)

diff --git a/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala 
b/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala
index d01f1c9..5ebade2 100644
--- a/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala
+++ b/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala
@@ -16,29 +16,30 @@
  */
 package kafka.raft
 
-import java.io.{File, IOException}
-import java.nio.file.{Files, NoSuchFileException}
-import java.util.concurrent.ConcurrentSkipListSet
+import java.io.File
+import java.nio.file.{Files, NoSuchFileException, Path}
 import java.util.{Optional, Properties}
 
 import kafka.api.ApiVersion
 import kafka.log.{AppendOrigin, Log, LogConfig, LogOffsetSnapshot, 
SnapshotGenerated}
 import kafka.server.{BrokerTopicStats, FetchHighWatermark, FetchLogEnd, 
LogDirFailureChannel}
-import kafka.utils.{Logging, Scheduler}
+import kafka.utils.{CoreUtils, Logging, Scheduler}
 import org.apache.kafka.common.record.{MemoryRecords, Records}
-import org.apache.kafka.common.utils.{Time, Utils}
+import org.apache.kafka.common.utils.Time
 import org.apache.kafka.common.{KafkaException, TopicPartition, Uuid}
 import org.apache.kafka.raft.{Isolation, LogAppendInfo, LogFetchInfo, 
LogOffsetMetadata, OffsetAndEpoch, OffsetMetadata, ReplicatedLog}
 import org.apache.kafka.snapshot.{FileRawSnapshotReader, 
FileRawSnapshotWriter, RawSnapshotReader, RawSnapshotWriter, SnapshotPath, 
Snapshots}
 
+import scala.annotation.nowarn
+import scala.collection.mutable
 import scala.compat.java8.OptionConverters._
 
 final class KafkaMetadataLog private (
   log: Log,
   scheduler: Scheduler,
-  // This object needs to be thread-safe because it is used by the 
snapshotting thread to notify the
-  // polling thread when snapshots are created.
-  snapshotIds: ConcurrentSkipListSet[OffsetAndEpoch],
+  // Access to this object needs to be synchronized because it is used by the 
snapshotting thread to notify the
+  // polling thread when snapshots are created. This object is also used to 
store any opened snapshot reader.
+  snapshots: mutable.TreeMap[OffsetAndEpoch, Option[FileRawSnapshotReader]],
   topicPartition: TopicPartition,
   maxFetchSizeInBytes: Int,
   val fileDeleteDelayMs: Long // Visible for testing,
@@ -161,19 +162,24 @@ final class KafkaMetadataLog private (
 
   override def truncateToLatestSnapshot(): Boolean = {
     val latestEpoch = log.latestEpoch.getOrElse(0)
-    latestSnapshotId().asScala match {
-      case Some(snapshotId) if (snapshotId.epoch > latestEpoch ||
-        (snapshotId.epoch == latestEpoch && snapshotId.offset > 
endOffset().offset)) =>
+    val (truncated, forgottenSnapshots) = latestSnapshotId().asScala match {
+      case Some(snapshotId) if (
+          snapshotId.epoch > latestEpoch ||
+          (snapshotId.epoch == latestEpoch && snapshotId.offset > 
endOffset().offset)
+        ) =>
         // Truncate the log fully if the latest snapshot is greater than the 
log end offset
-
         log.truncateFullyAndStartAt(snapshotId.offset)
-        // Delete snapshot after truncating
-        removeSnapshotFilesBefore(snapshotId)
-
-        true
 
-      case _ => false
+        // Forget snapshots less than the log start offset
+        snapshots synchronized {
+          (true, forgetSnapshotsBefore(snapshotId))
+        }
+      case _ =>
+        (false, mutable.TreeMap.empty[OffsetAndEpoch, 
Option[FileRawSnapshotReader]])
     }
+
+    removeSnapshots(forgottenSnapshots)
+    truncated
   }
 
   override def initializeLeaderEpoch(epoch: Int): Unit = {
@@ -242,85 +248,116 @@ final class KafkaMetadataLog private (
   }
 
   override def readSnapshot(snapshotId: OffsetAndEpoch): 
Optional[RawSnapshotReader] = {
-    try {
-      if (snapshotIds.contains(snapshotId)) {
-        Optional.of(FileRawSnapshotReader.open(log.dir.toPath, snapshotId))
-      } else {
-        Optional.empty()
+    snapshots synchronized {
+      val reader = snapshots.get(snapshotId) match {
+        case None =>
+          // Snapshot doesn't exist
+          None
+        case Some(None) =>
+          // Snapshot exists but has never been read before
+          try {
+            val snapshotReader = 
Some(FileRawSnapshotReader.open(log.dir.toPath, snapshotId))
+            snapshots.put(snapshotId, snapshotReader)
+            snapshotReader
+          } catch {
+            case _: NoSuchFileException =>
+              // Snapshot doesn't exist in the data dir; remove
+              val path = Snapshots.snapshotPath(log.dir.toPath, snapshotId)
+              warn(s"Couldn't read $snapshotId; expected to find snapshot file 
$path")
+              snapshots.remove(snapshotId)
+              None
+          }
+        case Some(value) =>
+          // Snapshot exists and it is already open; do nothing
+          value
       }
-    } catch {
-      case _: NoSuchFileException =>
-        Optional.empty()
+
+      reader.asJava.asInstanceOf[Optional[RawSnapshotReader]]
     }
   }
 
   override def latestSnapshotId(): Optional[OffsetAndEpoch] = {
-    val descending = snapshotIds.descendingIterator
-    if (descending.hasNext) {
-      Optional.of(descending.next)
-    } else {
-      Optional.empty()
+    snapshots synchronized {
+      snapshots.lastOption.map { case (snapshotId, _) => snapshotId }.asJava
     }
   }
 
   override def earliestSnapshotId(): Optional[OffsetAndEpoch] = {
-    val ascendingIterator = snapshotIds.iterator
-    if (ascendingIterator.hasNext) {
-      Optional.of(ascendingIterator.next)
-    } else {
-      Optional.empty()
+    snapshots synchronized {
+      snapshots.headOption.map { case (snapshotId, _) => snapshotId }.asJava
     }
   }
 
   override def onSnapshotFrozen(snapshotId: OffsetAndEpoch): Unit = {
-    snapshotIds.add(snapshotId)
+    snapshots synchronized {
+      snapshots.put(snapshotId, None)
+    }
   }
 
   override def deleteBeforeSnapshot(logStartSnapshotId: OffsetAndEpoch): 
Boolean = {
-    latestSnapshotId().asScala match {
-      case Some(snapshotId) if (snapshotIds.contains(logStartSnapshotId) &&
-        startOffset < logStartSnapshotId.offset &&
-        logStartSnapshotId.offset <= snapshotId.offset &&
-        log.maybeIncrementLogStartOffset(logStartSnapshotId.offset, 
SnapshotGenerated)) =>
-        log.deleteOldSegments()
+    val (deleted, forgottenSnapshots) = snapshots synchronized {
+      latestSnapshotId().asScala match {
+        case Some(snapshotId) if (snapshots.contains(logStartSnapshotId) &&
+          startOffset < logStartSnapshotId.offset &&
+          logStartSnapshotId.offset <= snapshotId.offset &&
+          log.maybeIncrementLogStartOffset(logStartSnapshotId.offset, 
SnapshotGenerated)) =>
+
+          // Delete all segments that have a "last offset" less than the log 
start offset
+          log.deleteOldSegments()
 
-        // Delete snapshot after increasing LogStartOffset
-        removeSnapshotFilesBefore(logStartSnapshotId)
+          // Forget snapshots less than the log start offset
+          (true, forgetSnapshotsBefore(logStartSnapshotId))
+        case _ =>
+          (false, mutable.TreeMap.empty[OffsetAndEpoch, 
Option[FileRawSnapshotReader]])
+      }
+    }
 
-        true
+    removeSnapshots(forgottenSnapshots)
+    deleted
+  }
 
-      case _ => false
-    }
+  /**
+   * Forget the snapshots earlier than a given snapshot id and return the 
associated
+   * snapshot readers.
+   *
+   * This method assumes that the lock for `snapshots` is already held.
+   */
+  @nowarn("cat=deprecation") // Needed for TreeMap.until
+  private def forgetSnapshotsBefore(
+    logStartSnapshotId: OffsetAndEpoch
+  ): mutable.TreeMap[OffsetAndEpoch, Option[FileRawSnapshotReader]] = {
+    val expiredSnapshots = snapshots.until(logStartSnapshotId).clone()
+    snapshots --= expiredSnapshots.keys
+
+    expiredSnapshots
   }
 
   /**
-   * Removes all snapshots on the log directory whose epoch and end offset is 
less than the giving epoch and end offset.
+   * Rename the given snapshots on the log directory. Asynchronously, close 
and delete the
+   * given snapshots after some delay.
    */
-  private def removeSnapshotFilesBefore(logStartSnapshotId: OffsetAndEpoch): 
Unit = {
-    val expiredSnapshotIdsIter = snapshotIds.headSet(logStartSnapshotId, 
false).iterator
-    while (expiredSnapshotIdsIter.hasNext) {
-      val snapshotId = expiredSnapshotIdsIter.next()
-      // If snapshotIds contains a snapshot id, the KafkaRaftClient and 
Listener can expect that the snapshot exists
-      // on the file system, so we should first remove snapshotId and then 
delete snapshot file.
-      expiredSnapshotIdsIter.remove()
-
-      val path = Snapshots.snapshotPath(log.dir.toPath, snapshotId)
-      val destination = Snapshots.deleteRename(path, snapshotId)
-      try {
-        Utils.atomicMoveWithFallback(path, destination, false)
-      } catch {
-        case e: IOException =>
-          error(s"Error renaming snapshot file: $path to $destination", e)
-      }
+  private def removeSnapshots(
+    expiredSnapshots: mutable.TreeMap[OffsetAndEpoch, 
Option[FileRawSnapshotReader]]
+  ): Unit = {
+    expiredSnapshots.foreach { case (snapshotId, _) =>
+      Snapshots.markForDelete(log.dir.toPath, snapshotId)
+    }
+
+    if (expiredSnapshots.nonEmpty) {
       scheduler.schedule(
-        "delete-snapshot-file",
-        () => Snapshots.deleteSnapshotIfExists(log.dir.toPath, snapshotId),
-        fileDeleteDelayMs)
+        "delete-snapshot-files",
+        KafkaMetadataLog.deleteSnapshotFiles(log.dir.toPath, expiredSnapshots),
+        fileDeleteDelayMs
+      )
     }
   }
 
   override def close(): Unit = {
     log.close()
+    snapshots synchronized {
+      snapshots.values.flatten.foreach(_.close())
+      snapshots.clear()
+    }
   }
 }
 
@@ -376,8 +413,8 @@ object KafkaMetadataLog {
 
   private def recoverSnapshots(
     log: Log
-  ): ConcurrentSkipListSet[OffsetAndEpoch] = {
-    val snapshotIds = new ConcurrentSkipListSet[OffsetAndEpoch]()
+  ): mutable.TreeMap[OffsetAndEpoch, Option[FileRawSnapshotReader]] = {
+    val snapshots = mutable.TreeMap.empty[OffsetAndEpoch, 
Option[FileRawSnapshotReader]]
     // Scan the log directory; deleting partial snapshots and older snapshot, 
only remembering immutable snapshots start
     // from logStartOffset
     Files
@@ -397,11 +434,22 @@ object KafkaMetadataLog {
             // Delete partial snapshot, deleted snapshot and older snapshot
             Files.deleteIfExists(snapshotPath.path)
           } else {
-            snapshotIds.add(snapshotPath.snapshotId)
+            snapshots.put(snapshotPath.snapshotId, None)
           }
         }
       }
-    snapshotIds
+    snapshots
   }
 
+  private def deleteSnapshotFiles(
+    logDir: Path,
+    expiredSnapshots: mutable.TreeMap[OffsetAndEpoch, 
Option[FileRawSnapshotReader]]
+  ): () => Unit = () => {
+    expiredSnapshots.foreach { case (snapshotId, snapshotReader) =>
+      snapshotReader.foreach { reader =>
+        CoreUtils.swallow(reader.close(), this)
+      }
+      Snapshots.deleteIfExists(logDir, snapshotId)
+    }
+  }
 }
diff --git a/core/src/test/scala/kafka/raft/KafkaMetadataLogTest.scala 
b/core/src/test/scala/kafka/raft/KafkaMetadataLogTest.scala
index f9ec73e..c76ccaa 100644
--- a/core/src/test/scala/kafka/raft/KafkaMetadataLogTest.scala
+++ b/core/src/test/scala/kafka/raft/KafkaMetadataLogTest.scala
@@ -99,9 +99,7 @@ final class KafkaMetadataLogTest {
       snapshot.freeze()
     }
 
-    TestUtils.resource(log.readSnapshot(snapshotId).get()) { snapshot =>
-      assertEquals(0, snapshot.sizeInBytes())
-    }
+    assertEquals(0, log.readSnapshot(snapshotId).get().sizeInBytes())
   }
 
   @Test
diff --git a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java 
b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
index 190f39b..f1a928d 100644
--- a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
+++ b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
@@ -1246,51 +1246,50 @@ public class KafkaRaftClient<T> implements 
RaftClient<T> {
             );
         }
 
-        try (RawSnapshotReader snapshot = snapshotOpt.get()) {
-            long snapshotSize = snapshot.sizeInBytes();
-            if (partitionSnapshot.position() < 0 || 
partitionSnapshot.position() >= snapshotSize) {
-                return FetchSnapshotResponse.singleton(
-                    log.topicPartition(),
-                    responsePartitionSnapshot -> 
addQuorumLeader(responsePartitionSnapshot)
-                        .setErrorCode(Errors.POSITION_OUT_OF_RANGE.code())
-                );
-            }
-
-            if (partitionSnapshot.position() > Integer.MAX_VALUE) {
-                throw new IllegalStateException(
-                    String.format(
-                        "Trying to fetch a snapshot with size (%s) and a 
position (%s) larger than %s",
-                        snapshotSize,
-                        partitionSnapshot.position(),
-                        Integer.MAX_VALUE
-                    )
-                );
-            }
-
-            int maxSnapshotSize;
-            try {
-                maxSnapshotSize = Math.toIntExact(snapshotSize);
-            } catch (ArithmeticException e) {
-                maxSnapshotSize = Integer.MAX_VALUE;
-            }
-
-            UnalignedRecords records = 
snapshot.slice(partitionSnapshot.position(), Math.min(data.maxBytes(), 
maxSnapshotSize));
-
+        RawSnapshotReader snapshot = snapshotOpt.get();
+        long snapshotSize = snapshot.sizeInBytes();
+        if (partitionSnapshot.position() < 0 || partitionSnapshot.position() 
>= snapshotSize) {
             return FetchSnapshotResponse.singleton(
                 log.topicPartition(),
-                responsePartitionSnapshot -> {
-                    addQuorumLeader(responsePartitionSnapshot)
-                        .snapshotId()
-                        .setEndOffset(snapshotId.offset)
-                        .setEpoch(snapshotId.epoch);
-
-                    return responsePartitionSnapshot
-                        .setSize(snapshotSize)
-                        .setPosition(partitionSnapshot.position())
-                        .setUnalignedRecords(records);
-                }
+                responsePartitionSnapshot -> 
addQuorumLeader(responsePartitionSnapshot)
+                    .setErrorCode(Errors.POSITION_OUT_OF_RANGE.code())
+            );
+        }
+
+        if (partitionSnapshot.position() > Integer.MAX_VALUE) {
+            throw new IllegalStateException(
+                String.format(
+                    "Trying to fetch a snapshot with size (%s) and a position 
(%s) larger than %s",
+                    snapshotSize,
+                    partitionSnapshot.position(),
+                    Integer.MAX_VALUE
+                )
             );
         }
+
+        int maxSnapshotSize;
+        try {
+            maxSnapshotSize = Math.toIntExact(snapshotSize);
+        } catch (ArithmeticException e) {
+            maxSnapshotSize = Integer.MAX_VALUE;
+        }
+
+        UnalignedRecords records = 
snapshot.slice(partitionSnapshot.position(), Math.min(data.maxBytes(), 
maxSnapshotSize));
+
+        return FetchSnapshotResponse.singleton(
+            log.topicPartition(),
+            responsePartitionSnapshot -> {
+                addQuorumLeader(responsePartitionSnapshot)
+                    .snapshotId()
+                    .setEndOffset(snapshotId.offset)
+                    .setEpoch(snapshotId.epoch);
+
+                return responsePartitionSnapshot
+                    .setSize(snapshotSize)
+                    .setPosition(partitionSnapshot.position())
+                    .setUnalignedRecords(records);
+            }
+        );
     }
 
     private boolean handleFetchSnapshotResponse(
diff --git 
a/raft/src/main/java/org/apache/kafka/snapshot/FileRawSnapshotReader.java 
b/raft/src/main/java/org/apache/kafka/snapshot/FileRawSnapshotReader.java
index 820230e..59d3c9c 100644
--- a/raft/src/main/java/org/apache/kafka/snapshot/FileRawSnapshotReader.java
+++ b/raft/src/main/java/org/apache/kafka/snapshot/FileRawSnapshotReader.java
@@ -22,9 +22,10 @@ import org.apache.kafka.common.record.UnalignedRecords;
 import org.apache.kafka.raft.OffsetAndEpoch;
 
 import java.io.IOException;
+import java.io.UncheckedIOException;
 import java.nio.file.Path;
 
-public final class FileRawSnapshotReader implements RawSnapshotReader {
+public final class FileRawSnapshotReader implements RawSnapshotReader, 
AutoCloseable {
     private final FileRecords fileRecords;
     private final OffsetAndEpoch snapshotId;
 
@@ -54,8 +55,19 @@ public final class FileRawSnapshotReader implements 
RawSnapshotReader {
     }
 
     @Override
-    public void close() throws IOException {
-        fileRecords.close();
+    public void close() {
+        try {
+            fileRecords.close();
+        } catch (IOException e) {
+            throw new UncheckedIOException(
+                String.format(
+                    "Unable to close snapshot reader %s at %s",
+                    snapshotId,
+                    fileRecords.file
+                ),
+                e
+            );
+        }
     }
 
     /**
diff --git 
a/raft/src/main/java/org/apache/kafka/snapshot/RawSnapshotReader.java 
b/raft/src/main/java/org/apache/kafka/snapshot/RawSnapshotReader.java
index 506728d..1a51999 100644
--- a/raft/src/main/java/org/apache/kafka/snapshot/RawSnapshotReader.java
+++ b/raft/src/main/java/org/apache/kafka/snapshot/RawSnapshotReader.java
@@ -20,12 +20,10 @@ import org.apache.kafka.common.record.Records;
 import org.apache.kafka.common.record.UnalignedRecords;
 import org.apache.kafka.raft.OffsetAndEpoch;
 
-import java.io.Closeable;
-
 /**
  * Interface for reading snapshots as a sequence of records.
  */
-public interface RawSnapshotReader extends Closeable {
+public interface RawSnapshotReader {
     /**
      * Returns the end offset and epoch for the snapshot.
      */
diff --git a/raft/src/main/java/org/apache/kafka/snapshot/Snapshots.java 
b/raft/src/main/java/org/apache/kafka/snapshot/Snapshots.java
index 575358f..0974869 100644
--- a/raft/src/main/java/org/apache/kafka/snapshot/Snapshots.java
+++ b/raft/src/main/java/org/apache/kafka/snapshot/Snapshots.java
@@ -17,10 +17,12 @@
 package org.apache.kafka.snapshot;
 
 import org.apache.kafka.raft.OffsetAndEpoch;
+import org.apache.kafka.common.utils.Utils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
+import java.io.UncheckedIOException;
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.text.NumberFormat;
@@ -50,10 +52,6 @@ public final class Snapshots {
         return logDir;
     }
 
-    public static Path snapshotPath(Path logDir, OffsetAndEpoch snapshotId) {
-        return snapshotDir(logDir).resolve(filenameFromSnapshotId(snapshotId) 
+ SUFFIX);
-    }
-
     static String filenameFromSnapshotId(OffsetAndEpoch snapshotId) {
         return String.format("%s-%s", 
OFFSET_FORMATTER.format(snapshotId.offset), 
EPOCH_FORMATTER.format(snapshotId.epoch));
     }
@@ -62,10 +60,14 @@ public final class Snapshots {
         return source.resolveSibling(filenameFromSnapshotId(snapshotId) + 
SUFFIX);
     }
 
-    public static Path deleteRename(Path source, OffsetAndEpoch snapshotId) {
+    static Path deleteRename(Path source, OffsetAndEpoch snapshotId) {
         return source.resolveSibling(filenameFromSnapshotId(snapshotId) + 
DELETE_SUFFIX);
     }
 
+    public static Path snapshotPath(Path logDir, OffsetAndEpoch snapshotId) {
+        return snapshotDir(logDir).resolve(filenameFromSnapshotId(snapshotId) 
+ SUFFIX);
+    }
+
     public static Path createTempFile(Path logDir, OffsetAndEpoch snapshotId) 
throws IOException {
         Path dir = snapshotDir(logDir);
 
@@ -104,18 +106,36 @@ public final class Snapshots {
     }
 
     /**
-     * Delete the snapshot from the filesystem, the caller may firstly rename 
snapshot file to
-     * ${file}.deleted, so we try to delete the file as well as the renamed 
file if exists.
+     * Delete the snapshot from the filesystem.
      */
-    public static boolean deleteSnapshotIfExists(Path logDir, OffsetAndEpoch 
snapshotId) {
-        Path immutablePath = Snapshots.snapshotPath(logDir, snapshotId);
-        Path deletingPath = Snapshots.deleteRename(immutablePath, snapshotId);
+    public static boolean deleteIfExists(Path logDir, OffsetAndEpoch 
snapshotId) {
+        Path immutablePath = snapshotPath(logDir, snapshotId);
+        Path deletedPath = deleteRename(immutablePath, snapshotId);
         try {
-            return Files.deleteIfExists(immutablePath) | 
Files.deleteIfExists(deletingPath);
+            return Files.deleteIfExists(immutablePath) | 
Files.deleteIfExists(deletedPath);
         } catch (IOException e) {
-            log.error("Error deleting snapshot file " + deletingPath, e);
+            log.error("Error deleting snapshot files {} and {}", 
immutablePath, deletedPath, e);
             return false;
         }
     }
 
+    /**
+     * Mark a snapshot for deletion by renaming with the deleted suffix
+     */
+    public static void markForDelete(Path logDir, OffsetAndEpoch snapshotId) {
+        Path immutablePath = snapshotPath(logDir, snapshotId);
+        Path deletedPath = deleteRename(immutablePath, snapshotId);
+        try {
+            Utils.atomicMoveWithFallback(immutablePath, deletedPath, false);
+        } catch (IOException e) {
+            throw new UncheckedIOException(
+                String.format(
+                    "Error renaming snapshot file from %s to %s",
+                    immutablePath,
+                    deletedPath
+                ),
+                e
+            );
+        }
+    }
 }
diff --git 
a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientSnapshotTest.java 
b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientSnapshotTest.java
index 6fd5147..2acf287 100644
--- a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientSnapshotTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientSnapshotTest.java
@@ -551,32 +551,31 @@ final public class KafkaRaftClientSnapshotTest {
             snapshot.freeze();
         }
 
-        try (RawSnapshotReader snapshot = 
context.log.readSnapshot(snapshotId).get()) {
-            context.deliverRequest(
-                fetchSnapshotRequest(
-                    context.metadataPartition,
-                    epoch,
-                    snapshotId,
-                    Integer.MAX_VALUE,
-                    0
-                )
-            );
-
-            context.client.poll();
-
-            FetchSnapshotResponseData.PartitionSnapshot response = context
-                .assertSentFetchSnapshotResponse(context.metadataPartition)
-                .get();
-
-            assertEquals(Errors.NONE, Errors.forCode(response.errorCode()));
-            assertEquals(snapshot.sizeInBytes(), response.size());
-            assertEquals(0, response.position());
-            assertEquals(snapshot.sizeInBytes(), 
response.unalignedRecords().sizeInBytes());
-
-            UnalignedMemoryRecords memoryRecords = (UnalignedMemoryRecords) 
snapshot.slice(0, Math.toIntExact(snapshot.sizeInBytes()));
-
-            assertEquals(memoryRecords.buffer(), ((UnalignedMemoryRecords) 
response.unalignedRecords()).buffer());
-        }
+        RawSnapshotReader snapshot = 
context.log.readSnapshot(snapshotId).get();
+        context.deliverRequest(
+            fetchSnapshotRequest(
+                context.metadataPartition,
+                epoch,
+                snapshotId,
+                Integer.MAX_VALUE,
+                0
+            )
+        );
+
+        context.client.poll();
+
+        FetchSnapshotResponseData.PartitionSnapshot response = context
+            .assertSentFetchSnapshotResponse(context.metadataPartition)
+            .get();
+
+        assertEquals(Errors.NONE, Errors.forCode(response.errorCode()));
+        assertEquals(snapshot.sizeInBytes(), response.size());
+        assertEquals(0, response.position());
+        assertEquals(snapshot.sizeInBytes(), 
response.unalignedRecords().sizeInBytes());
+
+        UnalignedMemoryRecords memoryRecords = (UnalignedMemoryRecords) 
snapshot.slice(0, Math.toIntExact(snapshot.sizeInBytes()));
+
+        assertEquals(memoryRecords.buffer(), ((UnalignedMemoryRecords) 
response.unalignedRecords()).buffer());
     }
 
     @Test
@@ -594,62 +593,61 @@ final public class KafkaRaftClientSnapshotTest {
             snapshot.freeze();
         }
 
-        try (RawSnapshotReader snapshot = 
context.log.readSnapshot(snapshotId).get()) {
-            // Fetch half of the snapshot
-            context.deliverRequest(
-                fetchSnapshotRequest(
-                    context.metadataPartition,
-                    epoch,
-                    snapshotId,
-                    Math.toIntExact(snapshot.sizeInBytes() / 2),
-                    0
-                )
-            );
-
-            context.client.poll();
-
-            FetchSnapshotResponseData.PartitionSnapshot response = context
-                .assertSentFetchSnapshotResponse(context.metadataPartition)
-                .get();
-
-            assertEquals(Errors.NONE, Errors.forCode(response.errorCode()));
-            assertEquals(snapshot.sizeInBytes(), response.size());
-            assertEquals(0, response.position());
-            assertEquals(snapshot.sizeInBytes() / 2, 
response.unalignedRecords().sizeInBytes());
-
-            UnalignedMemoryRecords memoryRecords = (UnalignedMemoryRecords) 
snapshot.slice(0, Math.toIntExact(snapshot.sizeInBytes()));
-            ByteBuffer snapshotBuffer = memoryRecords.buffer();
-
-            ByteBuffer responseBuffer = 
ByteBuffer.allocate(Math.toIntExact(snapshot.sizeInBytes()));
-            responseBuffer.put(((UnalignedMemoryRecords) 
response.unalignedRecords()).buffer());
-
-            ByteBuffer expectedBytes = snapshotBuffer.duplicate();
-            expectedBytes.limit(Math.toIntExact(snapshot.sizeInBytes() / 2));
-
-            assertEquals(expectedBytes, responseBuffer.duplicate().flip());
-
-            // Fetch the remainder of the snapshot
-            context.deliverRequest(
-                fetchSnapshotRequest(
-                    context.metadataPartition,
-                    epoch,
-                    snapshotId,
-                    Integer.MAX_VALUE,
-                    responseBuffer.position()
-                )
-            );
-
-            context.client.poll();
-
-            response = 
context.assertSentFetchSnapshotResponse(context.metadataPartition).get();
-            assertEquals(Errors.NONE, Errors.forCode(response.errorCode()));
-            assertEquals(snapshot.sizeInBytes(), response.size());
-            assertEquals(responseBuffer.position(), response.position());
-            assertEquals(snapshot.sizeInBytes() - (snapshot.sizeInBytes() / 
2), response.unalignedRecords().sizeInBytes());
-
-            responseBuffer.put(((UnalignedMemoryRecords) 
response.unalignedRecords()).buffer());
-            assertEquals(snapshotBuffer, responseBuffer.flip());
-        }
+        RawSnapshotReader snapshot = 
context.log.readSnapshot(snapshotId).get();
+        // Fetch half of the snapshot
+        context.deliverRequest(
+            fetchSnapshotRequest(
+                context.metadataPartition,
+                epoch,
+                snapshotId,
+                Math.toIntExact(snapshot.sizeInBytes() / 2),
+                0
+            )
+        );
+
+        context.client.poll();
+
+        FetchSnapshotResponseData.PartitionSnapshot response = context
+            .assertSentFetchSnapshotResponse(context.metadataPartition)
+            .get();
+
+        assertEquals(Errors.NONE, Errors.forCode(response.errorCode()));
+        assertEquals(snapshot.sizeInBytes(), response.size());
+        assertEquals(0, response.position());
+        assertEquals(snapshot.sizeInBytes() / 2, 
response.unalignedRecords().sizeInBytes());
+
+        UnalignedMemoryRecords memoryRecords = (UnalignedMemoryRecords) 
snapshot.slice(0, Math.toIntExact(snapshot.sizeInBytes()));
+        ByteBuffer snapshotBuffer = memoryRecords.buffer();
+
+        ByteBuffer responseBuffer = 
ByteBuffer.allocate(Math.toIntExact(snapshot.sizeInBytes()));
+        responseBuffer.put(((UnalignedMemoryRecords) 
response.unalignedRecords()).buffer());
+
+        ByteBuffer expectedBytes = snapshotBuffer.duplicate();
+        expectedBytes.limit(Math.toIntExact(snapshot.sizeInBytes() / 2));
+
+        assertEquals(expectedBytes, responseBuffer.duplicate().flip());
+
+        // Fetch the remainder of the snapshot
+        context.deliverRequest(
+            fetchSnapshotRequest(
+                context.metadataPartition,
+                epoch,
+                snapshotId,
+                Integer.MAX_VALUE,
+                responseBuffer.position()
+            )
+        );
+
+        context.client.poll();
+
+        response = 
context.assertSentFetchSnapshotResponse(context.metadataPartition).get();
+        assertEquals(Errors.NONE, Errors.forCode(response.errorCode()));
+        assertEquals(snapshot.sizeInBytes(), response.size());
+        assertEquals(responseBuffer.position(), response.position());
+        assertEquals(snapshot.sizeInBytes() - (snapshot.sizeInBytes() / 2), 
response.unalignedRecords().sizeInBytes());
+
+        responseBuffer.put(((UnalignedMemoryRecords) 
response.unalignedRecords()).buffer());
+        assertEquals(snapshotBuffer, responseBuffer.flip());
     }
 
     @Test
@@ -714,24 +712,23 @@ final public class KafkaRaftClientSnapshotTest {
         assertEquals(epoch, response.currentLeader().leaderEpoch());
         assertEquals(localId, response.currentLeader().leaderId());
 
-        try (RawSnapshotReader snapshot = 
context.log.readSnapshot(snapshotId).get()) {
-            context.deliverRequest(
-                fetchSnapshotRequest(
-                    context.metadataPartition,
-                    epoch,
-                    snapshotId,
-                    Integer.MAX_VALUE,
-                    snapshot.sizeInBytes()
-                )
-            );
-
-            context.client.poll();
-
-            response = 
context.assertSentFetchSnapshotResponse(context.metadataPartition).get();
-            assertEquals(Errors.POSITION_OUT_OF_RANGE, 
Errors.forCode(response.errorCode()));
-            assertEquals(epoch, response.currentLeader().leaderEpoch());
-            assertEquals(localId, response.currentLeader().leaderId());
-        }
+        RawSnapshotReader snapshot = 
context.log.readSnapshot(snapshotId).get();
+        context.deliverRequest(
+            fetchSnapshotRequest(
+                context.metadataPartition,
+                epoch,
+                snapshotId,
+                Integer.MAX_VALUE,
+                snapshot.sizeInBytes()
+            )
+        );
+
+        context.client.poll();
+
+        response = 
context.assertSentFetchSnapshotResponse(context.metadataPartition).get();
+        assertEquals(Errors.POSITION_OUT_OF_RANGE, 
Errors.forCode(response.errorCode()));
+        assertEquals(epoch, response.currentLeader().leaderEpoch());
+        assertEquals(localId, response.currentLeader().leaderId());
     }
 
     @Test
@@ -909,15 +906,14 @@ final public class KafkaRaftClientSnapshotTest {
         context.assertFetchRequestData(fetchRequest, epoch, snapshotId.offset, 
snapshotId.epoch);
 
         // Check that the snapshot was written to the log
-        try (RawSnapshotReader snapshot = 
context.log.readSnapshot(snapshotId).get()) {
-            assertEquals(memorySnapshot.buffer().remaining(), 
snapshot.sizeInBytes());
-            SnapshotWriterReaderTest.assertSnapshot(Arrays.asList(records), 
snapshot);
-        }
+        RawSnapshotReader snapshot = 
context.log.readSnapshot(snapshotId).get();
+        assertEquals(memorySnapshot.buffer().remaining(), 
snapshot.sizeInBytes());
+        SnapshotWriterReaderTest.assertSnapshot(Arrays.asList(records), 
snapshot);
 
         // Check that listener was notified of the new snapshot
-        try (SnapshotReader<String> snapshot = 
context.listener.drainHandledSnapshot().get()) {
-            assertEquals(snapshotId, snapshot.snapshotId());
-            SnapshotWriterReaderTest.assertSnapshot(Arrays.asList(records), 
snapshot);
+        try (SnapshotReader<String> reader = 
context.listener.drainHandledSnapshot().get()) {
+            assertEquals(snapshotId, reader.snapshotId());
+            SnapshotWriterReaderTest.assertSnapshot(Arrays.asList(records), 
reader);
         }
     }
 
@@ -1013,15 +1009,14 @@ final public class KafkaRaftClientSnapshotTest {
         context.assertFetchRequestData(fetchRequest, epoch, snapshotId.offset, 
snapshotId.epoch);
 
         // Check that the snapshot was written to the log
-        try (RawSnapshotReader snapshot = 
context.log.readSnapshot(snapshotId).get()) {
-            assertEquals(memorySnapshot.buffer().remaining(), 
snapshot.sizeInBytes());
-            SnapshotWriterReaderTest.assertSnapshot(Arrays.asList(records), 
snapshot);
-        }
+        RawSnapshotReader snapshot = 
context.log.readSnapshot(snapshotId).get();
+        assertEquals(memorySnapshot.buffer().remaining(), 
snapshot.sizeInBytes());
+        SnapshotWriterReaderTest.assertSnapshot(Arrays.asList(records), 
snapshot);
 
         // Check that listener was notified of the new snapshot
-        try (SnapshotReader<String> snapshot = 
context.listener.drainHandledSnapshot().get()) {
-            assertEquals(snapshotId, snapshot.snapshotId());
-            SnapshotWriterReaderTest.assertSnapshot(Arrays.asList(records), 
snapshot);
+        try (SnapshotReader<String> reader = 
context.listener.drainHandledSnapshot().get()) {
+            assertEquals(snapshotId, reader.snapshotId());
+            SnapshotWriterReaderTest.assertSnapshot(Arrays.asList(records), 
reader);
         }
     }
 
diff --git a/raft/src/test/java/org/apache/kafka/raft/MockLog.java 
b/raft/src/test/java/org/apache/kafka/raft/MockLog.java
index bf03a06..5feb943 100644
--- a/raft/src/test/java/org/apache/kafka/raft/MockLog.java
+++ b/raft/src/test/java/org/apache/kafka/raft/MockLog.java
@@ -709,8 +709,5 @@ public class MockLog implements ReplicatedLog {
         public Records records() {
             return data;
         }
-
-        @Override
-        public void close() {}
     }
 }
diff --git a/raft/src/test/java/org/apache/kafka/raft/MockLogTest.java 
b/raft/src/test/java/org/apache/kafka/raft/MockLogTest.java
index 35e13ce..4942139 100644
--- a/raft/src/test/java/org/apache/kafka/raft/MockLogTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/MockLogTest.java
@@ -441,9 +441,8 @@ public class MockLogTest {
             snapshot.freeze();
         }
 
-        try (RawSnapshotReader snapshot = log.readSnapshot(snapshotId).get()) {
-            assertEquals(0, snapshot.sizeInBytes());
-        }
+        RawSnapshotReader snapshot = log.readSnapshot(snapshotId).get();
+        assertEquals(0, snapshot.sizeInBytes());
     }
 
     @Test
diff --git a/raft/src/test/java/org/apache/kafka/snapshot/SnapshotsTest.java 
b/raft/src/test/java/org/apache/kafka/snapshot/SnapshotsTest.java
index 7960a83..ae89543 100644
--- a/raft/src/test/java/org/apache/kafka/snapshot/SnapshotsTest.java
+++ b/raft/src/test/java/org/apache/kafka/snapshot/SnapshotsTest.java
@@ -118,7 +118,7 @@ final public class SnapshotsTest {
                 // rename snapshot before deleting
                 Utils.atomicMoveWithFallback(snapshotPath, 
Snapshots.deleteRename(snapshotPath, snapshotId), false);
 
-            assertTrue(Snapshots.deleteSnapshotIfExists(logDirPath, 
snapshot.snapshotId()));
+            assertTrue(Snapshots.deleteIfExists(logDirPath, 
snapshot.snapshotId()));
             assertFalse(Files.exists(snapshotPath));
             assertFalse(Files.exists(Snapshots.deleteRename(snapshotPath, 
snapshotId)));
         }

Reply via email to