This is an automated email from the ASF dual-hosted git repository.

satishd pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new da2e8dce71a KAFKA-14551 Move/Rewrite LeaderEpochFileCache and its dependencies to the storage module. (#13046)
da2e8dce71a is described below

commit da2e8dce71a86143d0545a84b5449aabde48a44c
Author: Satish Duggana <[email protected]>
AuthorDate: Tue Feb 7 15:37:23 2023 +0530

    KAFKA-14551 Move/Rewrite LeaderEpochFileCache and its dependencies to the storage module. (#13046)
    
    KAFKA-14551 Move/Rewrite LeaderEpochFileCache and its dependencies to the storage module.
    
    For broader context on this change, you may want to look at KAFKA-14470: Move log layer to the storage module
    
    Reviewers: Ismael Juma <[email protected]>, Jun Rao <[email protected]>, Alexandre Dupriez <[email protected]>
---
 checkstyle/import-control.xml                      |   8 +-
 core/src/main/scala/kafka/log/LogLoader.scala      |   2 +-
 core/src/main/scala/kafka/log/LogSegment.scala     |  18 +-
 core/src/main/scala/kafka/log/UnifiedLog.scala     |  56 +--
 .../scala/kafka/log/remote/RemoteLogManager.scala  |   8 +-
 .../scala/kafka/server/LocalLeaderEndPoint.scala   |   6 +-
 .../scala/kafka/server/ReplicaFetcherThread.scala  |  11 +-
 .../CheckpointFileWithFailureHandler.scala         |  56 ---
 .../checkpoints/LeaderEpochCheckpointFile.scala    |  74 ----
 .../server/checkpoints/OffsetCheckpointFile.scala  |  18 +-
 .../kafka/server/epoch/LeaderEpochFileCache.scala  | 343 -----------------
 .../unit/kafka/cluster/PartitionLockTest.scala     |   2 +-
 .../scala/unit/kafka/cluster/PartitionTest.scala   |   9 +-
 .../test/scala/unit/kafka/log/LogLoaderTest.scala  |  10 +-
 .../test/scala/unit/kafka/log/LogSegmentTest.scala |  21 +-
 .../test/scala/unit/kafka/log/LogTestUtils.scala   |   5 +-
 .../test/scala/unit/kafka/log/UnifiedLogTest.scala |  41 ++-
 .../kafka/log/remote/RemoteLogManagerTest.scala    |  16 +-
 ...EpochCheckpointFileWithFailureHandlerTest.scala |  15 +-
 ...ffsetCheckpointFileWithFailureHandlerTest.scala |   4 +-
 ...chDrivenReplicationProtocolAcceptanceTest.scala |  36 +-
 .../server/epoch/LeaderEpochFileCacheTest.scala    | 410 +++++++++++----------
 .../kafka/server/log/internals/EpochEntry.java     |  51 +++
 .../CheckpointFileWithFailureHandler.java          |  63 ++++
 .../checkpoint/LeaderEpochCheckpoint.java          |  29 ++
 .../checkpoint/LeaderEpochCheckpointFile.java      |  78 ++++
 .../internals/epoch/LeaderEpochFileCache.java      | 403 ++++++++++++++++++++
 27 files changed, 990 insertions(+), 803 deletions(-)
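
A recurring pattern in the call-site changes below is bridging the rewritten Java API back into Scala: the new LeaderEpochFileCache returns java.util.Optional/OptionalInt where the old Scala class returned Option. A minimal sketch of those conversions (illustrative only, not part of the commit; the values are hypothetical stand-ins for what the cache returns), using the scala-java8-compat OptionConverters that the diff imports:

    import java.util.OptionalInt
    import scala.compat.java8.OptionConverters._

    val latestEpoch: OptionalInt = OptionalInt.of(5)  // stand-in for cache.latestEpoch
    val asOption: Option[Int] = latestEpoch.asScala   // Java -> Scala, as in UnifiedLog.latestEpoch
    val orZero: Int = latestEpoch.orElse(0)           // replaces Option.getOrElse(0), as in LocalLeaderEndPoint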

diff --git a/checkstyle/import-control.xml b/checkstyle/import-control.xml
index bc89d239e70..aa9917dfe6e 100644
--- a/checkstyle/import-control.xml
+++ b/checkstyle/import-control.xml
@@ -372,7 +372,7 @@
       <allow pkg="org.apache.kafka.server.log" />
       <allow pkg="org.apache.kafka.server.record" />
       <allow pkg="org.apache.kafka.test" />
-
+      <allow pkg="org.apache.kafka.storage"/>
       <subpackage name="remote">
         <allow pkg="scala.collection" />
       </subpackage>
@@ -380,6 +380,12 @@
     </subpackage>
   </subpackage>
 
+  <subpackage name="storage.internals">
+    <allow pkg="org.apache.kafka.server"/>
+    <allow pkg="org.apache.kafka.storage.internals"/>
+    <allow pkg="org.apache.kafka.common" />
+  </subpackage>
+
   <subpackage name="shell">
     <allow pkg="com.fasterxml.jackson" />
     <allow pkg="kafka.raft"/>
diff --git a/core/src/main/scala/kafka/log/LogLoader.scala b/core/src/main/scala/kafka/log/LogLoader.scala
index 1c1b2c9c037..81dfdc44547 100644
--- a/core/src/main/scala/kafka/log/LogLoader.scala
+++ b/core/src/main/scala/kafka/log/LogLoader.scala
@@ -21,7 +21,6 @@ import java.io.{File, IOException}
 import java.nio.file.{Files, NoSuchFileException}
 import kafka.common.LogSegmentOffsetOverflowException
 import kafka.log.UnifiedLog.{CleanedFileSuffix, DeletedFileSuffix, SwapFileSuffix, isIndexFile, isLogFile, offsetFromFile}
-import kafka.server.epoch.LeaderEpochFileCache
 import kafka.utils.Logging
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.errors.InvalidOffsetException
@@ -29,6 +28,7 @@ import org.apache.kafka.common.utils.{Time, Utils}
 import org.apache.kafka.snapshot.Snapshots
 import org.apache.kafka.server.log.internals.{CorruptIndexException, LoadedLogOffsets, LogConfig, LogDirFailureChannel, LogOffsetMetadata}
 import org.apache.kafka.server.util.Scheduler
+import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache
 
 import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
 import scala.collection.{Set, mutable}
diff --git a/core/src/main/scala/kafka/log/LogSegment.scala b/core/src/main/scala/kafka/log/LogSegment.scala
index 54a04f8aa7e..f38a3ac353a 100644
--- a/core/src/main/scala/kafka/log/LogSegment.scala
+++ b/core/src/main/scala/kafka/log/LogSegment.scala
@@ -18,13 +18,8 @@ package kafka.log
 
 import com.yammer.metrics.core.Timer
 
-import java.io.{File, IOException}
-import java.nio.file.{Files, NoSuchFileException}
-import java.nio.file.attribute.FileTime
-import java.util.concurrent.TimeUnit
 import kafka.common.LogSegmentOffsetOverflowException
 import kafka.metrics.KafkaMetricsGroup
-import kafka.server.epoch.LeaderEpochFileCache
 import kafka.utils._
 import org.apache.kafka.common.InvalidRecordException
 import org.apache.kafka.common.errors.CorruptRecordException
@@ -32,8 +27,13 @@ import org.apache.kafka.common.record.FileRecords.{LogOffsetPosition, TimestampA
 import org.apache.kafka.common.record._
 import org.apache.kafka.common.utils.{BufferSupplier, Time, Utils}
 import org.apache.kafka.server.log.internals.{AbortedTxn, AppendOrigin, CompletedTxn, LazyIndex, LogConfig, LogOffsetMetadata, OffsetIndex, OffsetPosition, TimeIndex, TimestampOffset, TransactionIndex, TxnIndexSearchResult, FetchDataInfo}
+import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache
 
+import java.io.{File, IOException}
+import java.nio.file.attribute.FileTime
+import java.nio.file.{Files, NoSuchFileException}
 import java.util.Optional
+import java.util.concurrent.TimeUnit
 import scala.compat.java8.OptionConverters._
 import scala.jdk.CollectionConverters._
 import scala.math._
@@ -249,13 +249,13 @@ class LogSegment private[log] (val log: FileRecords,
     if (batch.hasProducerId) {
       val producerId = batch.producerId
       val appendInfo = producerStateManager.prepareUpdate(producerId, origin = AppendOrigin.REPLICATION)
-      val maybeCompletedTxn = appendInfo.append(batch, Optional.empty()).asScala
+      val maybeCompletedTxn = appendInfo.append(batch, Optional.empty())
       producerStateManager.update(appendInfo)
-      maybeCompletedTxn.foreach { completedTxn =>
+      maybeCompletedTxn.ifPresent(completedTxn => {
         val lastStableOffset = producerStateManager.lastStableOffset(completedTxn)
         updateTxnIndex(completedTxn, lastStableOffset)
         producerStateManager.completeTxn(completedTxn)
-      }
+      })
     }
     producerStateManager.updateMapEndOffset(batch.lastOffset + 1)
   }
@@ -363,7 +363,7 @@ class LogSegment private[log] (val log: FileRecords,
 
         if (batch.magic >= RecordBatch.MAGIC_VALUE_V2) {
           leaderEpochCache.foreach { cache =>
-            if (batch.partitionLeaderEpoch >= 0 && cache.latestEpoch.forall(batch.partitionLeaderEpoch > _))
+            if (batch.partitionLeaderEpoch >= 0 && cache.latestEpoch.asScala.forall(batch.partitionLeaderEpoch > _))
               cache.assign(batch.partitionLeaderEpoch, batch.baseOffset)
           }
           updateProducerState(producerStateManager, batch)
diff --git a/core/src/main/scala/kafka/log/UnifiedLog.scala b/core/src/main/scala/kafka/log/UnifiedLog.scala
index 78c01a8c218..fcd59fa2c9d 100644
--- a/core/src/main/scala/kafka/log/UnifiedLog.scala
+++ b/core/src/main/scala/kafka/log/UnifiedLog.scala
@@ -26,8 +26,6 @@ import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
 import kafka.common.{OffsetsOutOfOrderException, UnexpectedAppendOffsetException}
 import kafka.log.remote.RemoteLogManager
 import kafka.metrics.KafkaMetricsGroup
-import kafka.server.checkpoints.LeaderEpochCheckpointFile
-import kafka.server.epoch.LeaderEpochFileCache
 import kafka.server.{BrokerTopicMetrics, BrokerTopicStats, OffsetAndEpoch, PartitionMetadataFile, RequestLocal}
 import kafka.utils._
 import org.apache.kafka.common.errors._
@@ -42,10 +40,12 @@ import org.apache.kafka.common.utils.{PrimitiveRef, Time, Utils}
 import org.apache.kafka.common.{InvalidRecordException, KafkaException, TopicPartition, Uuid}
 import org.apache.kafka.server.common.MetadataVersion
 import org.apache.kafka.server.common.MetadataVersion.IBP_0_10_0_IV0
-import org.apache.kafka.server.log.internals.{AbortedTxn, AppendOrigin, BatchMetadata, CompletedTxn, FetchDataInfo, FetchIsolation, LastRecord, LogConfig, LogDirFailureChannel, LogOffsetMetadata, LogValidator, ProducerAppendInfo}
+import org.apache.kafka.server.log.internals.{AbortedTxn, AppendOrigin, BatchMetadata, CompletedTxn, EpochEntry, FetchDataInfo, FetchIsolation, LastRecord, LogConfig, LogDirFailureChannel, LogOffsetMetadata, LogValidator, ProducerAppendInfo}
 import org.apache.kafka.server.log.remote.metadata.storage.TopicBasedRemoteLogMetadataManagerConfig
 import org.apache.kafka.server.record.BrokerCompressionType
 import org.apache.kafka.server.util.Scheduler
+import org.apache.kafka.storage.internals.checkpoint.LeaderEpochCheckpointFile
+import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache
 
 import scala.annotation.nowarn
 import scala.collection.mutable.ListBuffer
@@ -1007,11 +1007,12 @@ class UnifiedLog(@volatile var logStartOffset: Long,
     }
   }
 
-  def latestEpoch: Option[Int] = leaderEpochCache.flatMap(_.latestEpoch)
+  def latestEpoch: Option[Int] = leaderEpochCache.flatMap(_.latestEpoch.asScala)
 
   def endOffsetForEpoch(leaderEpoch: Int): Option[OffsetAndEpoch] = {
     leaderEpochCache.flatMap { cache =>
-      val (foundEpoch, foundOffset) = cache.endOffsetFor(leaderEpoch, logEndOffset)
+      val entry = cache.endOffsetFor(leaderEpoch, logEndOffset)
+      val (foundEpoch, foundOffset) = (entry.getKey(), entry.getValue())
       if (foundOffset == UNDEFINED_EPOCH_OFFSET)
         None
       else
@@ -1284,6 +1285,16 @@ class UnifiedLog(@volatile var logStartOffset: Long,
     maybeHandleIOException(s"Error while fetching offset by timestamp for 
$topicPartition in dir ${dir.getParent}") {
       debug(s"Searching offset for timestamp $targetTimestamp")
 
+      def latestEpochAsOptional(leaderEpochCache: Option[LeaderEpochFileCache]): Optional[Integer] = {
+        leaderEpochCache match {
+          case Some(cache) => {
+            val latestEpoch = cache.latestEpoch()
+            if (latestEpoch.isPresent) Optional.of(latestEpoch.getAsInt) else Optional.empty[Integer]()
+          }
+          case None => Optional.empty[Integer]()
+        }
+      }
+
       if (config.messageFormatVersion.isLessThan(IBP_0_10_0_IV0) &&
         targetTimestamp != ListOffsetsRequest.EARLIEST_TIMESTAMP &&
         targetTimestamp != ListOffsetsRequest.EARLIEST_LOCAL_TIMESTAMP &&
@@ -1298,36 +1309,37 @@ class UnifiedLog(@volatile var logStartOffset: Long,
         // The first cached epoch usually corresponds to the log start offset, but we have to verify this since
         // it may not be true following a message format version bump as the epoch will not be available for
         // log entries written in the older format.
-        val earliestEpochEntry = leaderEpochCache.flatMap(_.earliestEntry)
-        val epochOpt = earliestEpochEntry match {
-          case Some(entry) if entry.startOffset <= logStartOffset => Optional.of[Integer](entry.epoch)
-          case _ => Optional.empty[Integer]()
-        }
+        val earliestEpochEntry = leaderEpochCache.asJava.flatMap(_.earliestEntry())
+        val epochOpt = if (earliestEpochEntry.isPresent && earliestEpochEntry.get().startOffset <= logStartOffset) {
+          Optional.of[Integer](earliestEpochEntry.get().epoch)
+        } else Optional.empty[Integer]()
+
         Some(new TimestampAndOffset(RecordBatch.NO_TIMESTAMP, logStartOffset, epochOpt))
       } else if (targetTimestamp == ListOffsetsRequest.EARLIEST_LOCAL_TIMESTAMP) {
         val curLocalLogStartOffset = localLogStartOffset()
-        val earliestLocalLogEpochEntry = leaderEpochCache.flatMap(cache =>
-          cache.epochForOffset(curLocalLogStartOffset).flatMap(cache.epochEntry))
-        val epochOpt = earliestLocalLogEpochEntry match {
-          case Some(entry) if entry.startOffset <= curLocalLogStartOffset => Optional.of[Integer](entry.epoch)
-          case _ => Optional.empty[Integer]()
-        }
+
+        val earliestLocalLogEpochEntry = leaderEpochCache.asJava.flatMap(cache => {
+          val epoch = cache.epochForOffset(curLocalLogStartOffset)
+          if (epoch.isPresent) (cache.epochEntry(epoch.getAsInt)) else Optional.empty[EpochEntry]()
+        })
+
+        val epochOpt = if (earliestLocalLogEpochEntry.isPresent && earliestLocalLogEpochEntry.get().startOffset <= curLocalLogStartOffset)
+          Optional.of[Integer](earliestLocalLogEpochEntry.get().epoch)
+        else Optional.empty[Integer]()
+
         Some(new TimestampAndOffset(RecordBatch.NO_TIMESTAMP, curLocalLogStartOffset, epochOpt))
       } else if (targetTimestamp == ListOffsetsRequest.LATEST_TIMESTAMP) {
-        val latestEpochOpt = leaderEpochCache.flatMap(_.latestEpoch).map(_.asInstanceOf[Integer])
-        val epochOptional = Optional.ofNullable(latestEpochOpt.orNull)
-        Some(new TimestampAndOffset(RecordBatch.NO_TIMESTAMP, logEndOffset, epochOptional))
+        Some(new TimestampAndOffset(RecordBatch.NO_TIMESTAMP, logEndOffset, latestEpochAsOptional(leaderEpochCache)))
       } else if (targetTimestamp == ListOffsetsRequest.MAX_TIMESTAMP) {
         // Cache to avoid race conditions. `toBuffer` is faster than most alternatives and provides
         // constant time access while being safe to use with concurrent collections unlike `toArray`.
         val segmentsCopy = logSegments.toBuffer
         val latestTimestampSegment = segmentsCopy.maxBy(_.maxTimestampSoFar)
-        val latestEpochOpt = leaderEpochCache.flatMap(_.latestEpoch).map(_.asInstanceOf[Integer])
-        val epochOptional = Optional.ofNullable(latestEpochOpt.orNull)
         val latestTimestampAndOffset = latestTimestampSegment.maxTimestampAndOffsetSoFar
+
         Some(new TimestampAndOffset(latestTimestampAndOffset.timestamp,
           latestTimestampAndOffset.offset,
-          epochOptional))
+          latestEpochAsOptional(leaderEpochCache)))
       } else {
         // We need to search the first segment whose largest timestamp is >= the target timestamp if there is one.
         val remoteOffset = if (remoteLogEnabled()) {
diff --git a/core/src/main/scala/kafka/log/remote/RemoteLogManager.scala b/core/src/main/scala/kafka/log/remote/RemoteLogManager.scala
index 8324028c5ed..8394c74afbd 100644
--- a/core/src/main/scala/kafka/log/remote/RemoteLogManager.scala
+++ b/core/src/main/scala/kafka/log/remote/RemoteLogManager.scala
@@ -19,14 +19,14 @@ package kafka.log.remote
 import kafka.cluster.Partition
 import kafka.metrics.KafkaMetricsGroup
 import kafka.server.KafkaConfig
-import kafka.server.epoch.LeaderEpochFileCache
 import kafka.utils.Logging
 import org.apache.kafka.common._
 import org.apache.kafka.common.record.FileRecords.TimestampAndOffset
 import org.apache.kafka.common.record.{RecordBatch, RemoteLogInputStream}
 import org.apache.kafka.common.utils.{ChildFirstClassLoader, Utils}
 import org.apache.kafka.server.log.remote.metadata.storage.ClassLoaderAwareRemoteLogMetadataManager
-import org.apache.kafka.server.log.remote.storage.{ClassLoaderAwareRemoteStorageManager, RemoteLogManagerConfig, RemoteLogMetadataManager, RemoteLogSegmentMetadata, RemoteStorageManager}
+import org.apache.kafka.server.log.remote.storage._
+import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache
 
 import java.io.{Closeable, InputStream}
 import java.security.{AccessController, PrivilegedAction}
@@ -256,8 +256,8 @@ class RemoteLogManager(rlmConfig: RemoteLogManagerConfig,
 
     // Get the respective epoch in which the starting-offset exists.
     var maybeEpoch = leaderEpochCache.epochForOffset(startingOffset)
-    while (maybeEpoch.nonEmpty) {
-      val epoch = maybeEpoch.get
+    while (maybeEpoch.isPresent) {
+      val epoch = maybeEpoch.getAsInt
       remoteLogMetadataManager.listRemoteLogSegments(new TopicIdPartition(topicId, tp), epoch).asScala
         .foreach(rlsMetadata =>
           if (rlsMetadata.maxTimestampMs() >= timestamp && rlsMetadata.endOffset() >= startingOffset) {
diff --git a/core/src/main/scala/kafka/server/LocalLeaderEndPoint.scala b/core/src/main/scala/kafka/server/LocalLeaderEndPoint.scala
index 6702e180f47..dd6fffd2a81 100644
--- a/core/src/main/scala/kafka/server/LocalLeaderEndPoint.scala
+++ b/core/src/main/scala/kafka/server/LocalLeaderEndPoint.scala
@@ -117,21 +117,21 @@ class LocalLeaderEndPoint(sourceBroker: BrokerEndPoint,
     val partition = replicaManager.getPartitionOrException(topicPartition)
     val logStartOffset = partition.localLogOrException.logStartOffset
     val epoch = partition.localLogOrException.leaderEpochCache.get.epochForOffset(logStartOffset)
-    (epoch.getOrElse(0), logStartOffset)
+    (epoch.orElse(0), logStartOffset)
   }
 
   override def fetchLatestOffset(topicPartition: TopicPartition, currentLeaderEpoch: Int): (Int, Long) = {
     val partition = replicaManager.getPartitionOrException(topicPartition)
     val logEndOffset = partition.localLogOrException.logEndOffset
     val epoch = partition.localLogOrException.leaderEpochCache.get.epochForOffset(logEndOffset)
-    (epoch.getOrElse(0), logEndOffset)
+    (epoch.orElse(0), logEndOffset)
   }
 
   override def fetchEarliestLocalOffset(topicPartition: TopicPartition, currentLeaderEpoch: Int): (Int, Long) = {
     val partition = replicaManager.getPartitionOrException(topicPartition)
     val localLogStartOffset = partition.localLogOrException.localLogStartOffset()
     val epoch = partition.localLogOrException.leaderEpochCache.get.epochForOffset(localLogStartOffset)
-    (epoch.getOrElse(0), localLogStartOffset)
+    (epoch.orElse(0), localLogStartOffset)
   }
 
   override def fetchEpochEndOffsets(partitions: collection.Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] = {
diff --git a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
index d96d9136924..852ab6bce8b 100644
--- a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
+++ b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
@@ -19,8 +19,6 @@ package kafka.server
 
 import kafka.log.remote.RemoteLogManager
 import kafka.log.{LeaderOffsetIncremented, LogAppendInfo, UnifiedLog}
-import kafka.server.checkpoints.LeaderEpochCheckpointFile
-import kafka.server.epoch.EpochEntry
 import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset
 import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.record.MemoryRecords
@@ -29,13 +27,14 @@ import org.apache.kafka.common.utils.Utils
 import org.apache.kafka.common.{KafkaException, TopicPartition}
 import org.apache.kafka.server.common.CheckpointFile.CheckpointReadBuffer
 import org.apache.kafka.server.common.MetadataVersion
+import org.apache.kafka.server.log.internals.EpochEntry
 import org.apache.kafka.server.log.remote.storage.{RemoteLogSegmentMetadata, RemoteStorageException, RemoteStorageManager}
+import org.apache.kafka.storage.internals.checkpoint.LeaderEpochCheckpointFile
 
 import java.io.{BufferedReader, File, InputStreamReader}
 import java.nio.charset.StandardCharsets
 import java.nio.file.{Files, StandardCopyOption}
 import scala.collection.mutable
-import scala.jdk.CollectionConverters._
 
 class ReplicaFetcherThread(name: String,
                            leader: LeaderEndPoint,
@@ -329,12 +328,12 @@ class ReplicaFetcherThread(name: String,
     nextOffset
   }
 
-  private def readLeaderEpochCheckpoint(rlm: RemoteLogManager, remoteLogSegmentMetadata: RemoteLogSegmentMetadata): collection.Seq[EpochEntry] = {
+  private def readLeaderEpochCheckpoint(rlm: RemoteLogManager, remoteLogSegmentMetadata: RemoteLogSegmentMetadata): java.util.List[EpochEntry] = {
     val inputStream = rlm.storageManager().fetchIndex(remoteLogSegmentMetadata, RemoteStorageManager.IndexType.LEADER_EPOCH)
     val bufferedReader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))
     try {
-      val readBuffer = new CheckpointReadBuffer[EpochEntry]("", bufferedReader,  0, LeaderEpochCheckpointFile.Formatter)
-      readBuffer.read().asScala.toSeq
+      val readBuffer = new CheckpointReadBuffer[EpochEntry]("", bufferedReader,  0, LeaderEpochCheckpointFile.FORMATTER)
+      readBuffer.read()
     } finally {
       bufferedReader.close()
     }
diff --git a/core/src/main/scala/kafka/server/checkpoints/CheckpointFileWithFailureHandler.scala b/core/src/main/scala/kafka/server/checkpoints/CheckpointFileWithFailureHandler.scala
deleted file mode 100644
index 0e669249bdd..00000000000
--- a/core/src/main/scala/kafka/server/checkpoints/CheckpointFileWithFailureHandler.scala
+++ /dev/null
@@ -1,56 +0,0 @@
-/**
-  * Licensed to the Apache Software Foundation (ASF) under one or more
-  * contributor license agreements.  See the NOTICE file distributed with
-  * this work for additional information regarding copyright ownership.
-  * The ASF licenses this file to You under the Apache License, Version 2.0
-  * (the "License"); you may not use this file except in compliance with
-  * the License.  You may obtain a copy of the License at
-  *
-  * http://www.apache.org/licenses/LICENSE-2.0
-  *
-  * Unless required by applicable law or agreed to in writing, software
-  * distributed under the License is distributed on an "AS IS" BASIS,
-  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-  * See the License for the specific language governing permissions and
-  * limitations under the License.
-  */
-package kafka.server.checkpoints
-
-import org.apache.kafka.common.errors.KafkaStorageException
-import org.apache.kafka.server.common.CheckpointFile
-import org.apache.kafka.server.log.internals.LogDirFailureChannel
-import CheckpointFile.EntryFormatter
-
-import java.io._
-import scala.collection.Seq
-import scala.jdk.CollectionConverters._
-
-class CheckpointFileWithFailureHandler[T](val file: File,
-                                          version: Int,
-                                          formatter: EntryFormatter[T],
-                                          logDirFailureChannel: LogDirFailureChannel,
-                                          logDir: String) {
-  private val checkpointFile = new CheckpointFile[T](file, version, formatter)
-
-  def write(entries: Iterable[T]): Unit = {
-      try {
-        checkpointFile.write(entries.toSeq.asJava)
-      } catch {
-        case e: IOException =>
-          val msg = s"Error while writing to checkpoint file 
${file.getAbsolutePath}"
-          logDirFailureChannel.maybeAddOfflineLogDir(logDir, msg, e)
-          throw new KafkaStorageException(msg, e)
-      }
-  }
-
-  def read(): Seq[T] = {
-      try {
-        checkpointFile.read().asScala
-      } catch {
-        case e: IOException =>
-          val msg = s"Error while reading checkpoint file 
${file.getAbsolutePath}"
-          logDirFailureChannel.maybeAddOfflineLogDir(logDir, msg, e)
-          throw new KafkaStorageException(msg, e)
-      }
-  }
-}
diff --git a/core/src/main/scala/kafka/server/checkpoints/LeaderEpochCheckpointFile.scala b/core/src/main/scala/kafka/server/checkpoints/LeaderEpochCheckpointFile.scala
deleted file mode 100644
index 93ef93d0bd3..00000000000
--- a/core/src/main/scala/kafka/server/checkpoints/LeaderEpochCheckpointFile.scala
+++ /dev/null
@@ -1,74 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package kafka.server.checkpoints
-
-import kafka.server.epoch.EpochEntry
-import org.apache.kafka.server.common.CheckpointFile.EntryFormatter
-import org.apache.kafka.server.log.internals.LogDirFailureChannel
-
-import java.io._
-import java.util.Optional
-import java.util.regex.Pattern
-import scala.collection._
-
-trait LeaderEpochCheckpoint {
-  def write(epochs: Iterable[EpochEntry]): Unit
-  def read(): Seq[EpochEntry]
-}
-
-object LeaderEpochCheckpointFile {
-  private val LeaderEpochCheckpointFilename = "leader-epoch-checkpoint"
-  private val WhiteSpacesPattern = Pattern.compile("\\s+")
-  private val CurrentVersion = 0
-
-  def newFile(dir: File): File = new File(dir, LeaderEpochCheckpointFilename)
-
-  object Formatter extends EntryFormatter[EpochEntry] {
-
-    override def toString(entry: EpochEntry): String = s"${entry.epoch} ${entry.startOffset}"
-
-    override def fromString(line: String): Optional[EpochEntry] = {
-      WhiteSpacesPattern.split(line) match {
-        case Array(epoch, offset) =>
-          Optional.of(EpochEntry(epoch.toInt, offset.toLong))
-        case _ => Optional.empty()
-      }
-    }
-
-  }
-}
-
-/**
- * This class persists a map of (LeaderEpoch => Offsets) to a file (for a certain replica)
- *
- * The format in the LeaderEpoch checkpoint file is like this:
- * -----checkpoint file begin------
- * 0                <- LeaderEpochCheckpointFile.currentVersion
- * 2                <- following entries size
- * 0  1     <- the format is: leader_epoch(int32) start_offset(int64)
- * 1  2
- * -----checkpoint file end----------
- */
-class LeaderEpochCheckpointFile(val file: File, logDirFailureChannel: LogDirFailureChannel = null) extends LeaderEpochCheckpoint {
-  import LeaderEpochCheckpointFile._
-
-  val checkpoint = new CheckpointFileWithFailureHandler[EpochEntry](file, CurrentVersion, Formatter, logDirFailureChannel, file.getParentFile.getParent)
-
-  def write(epochs: Iterable[EpochEntry]): Unit = checkpoint.write(epochs)
-
-  def read(): Seq[EpochEntry] = checkpoint.read()
-}
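
The checkpoint format documented in the deleted class above carries over to the Java rewrite, which exposes an equivalent FORMATTER (referenced from ReplicaFetcherThread earlier in this diff). As a small self-contained sketch of the documented "leader_epoch(int32) start_offset(int64)" line format (the helper name is hypothetical; it mirrors the deleted Formatter.fromString):

    // Parses one checkpoint entry line, e.g. "1 2" -> (epoch = 1, startOffset = 2).
    def parseEntry(line: String): Option[(Int, Long)] =
      line.trim.split("\\s+") match {
        case Array(epoch, offset) => Some((epoch.toInt, offset.toLong))
        case _                    => None
      }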
diff --git a/core/src/main/scala/kafka/server/checkpoints/OffsetCheckpointFile.scala b/core/src/main/scala/kafka/server/checkpoints/OffsetCheckpointFile.scala
index 483a186c8dd..c2a784b3a8a 100644
--- a/core/src/main/scala/kafka/server/checkpoints/OffsetCheckpointFile.scala
+++ b/core/src/main/scala/kafka/server/checkpoints/OffsetCheckpointFile.scala
@@ -16,10 +16,10 @@
   */
 package kafka.server.checkpoints
 
-import kafka.server.epoch.EpochEntry
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.server.common.CheckpointFile.EntryFormatter
-import org.apache.kafka.server.log.internals.LogDirFailureChannel
+import org.apache.kafka.server.log.internals.{EpochEntry, LogDirFailureChannel}
+import org.apache.kafka.storage.internals.checkpoint.CheckpointFileWithFailureHandler
 
 import java.io._
 import java.util.Optional
@@ -65,9 +65,19 @@ class OffsetCheckpointFile(val file: File, logDirFailureChannel: LogDirFailureCh
   val checkpoint = new CheckpointFileWithFailureHandler[(TopicPartition, Long)](file, OffsetCheckpointFile.CurrentVersion,
     OffsetCheckpointFile.Formatter, logDirFailureChannel, file.getParent)
 
-  def write(offsets: Map[TopicPartition, Long]): Unit = checkpoint.write(offsets)
+  def write(offsets: Map[TopicPartition, Long]): Unit = {
+    val list: java.util.List[(TopicPartition, Long)] = new java.util.ArrayList[(TopicPartition, Long)](offsets.size)
+    offsets.foreach(x => list.add(x))
+    checkpoint.write(list)
+  }
 
-  def read(): Map[TopicPartition, Long] = checkpoint.read().toMap
+  def read(): Map[TopicPartition, Long] = {
+    val list = checkpoint.read()
+    val result = mutable.Map.empty[TopicPartition, Long]
+    result.sizeHint(list.size())
+    list.forEach { case (tp, offset) => result(tp) = offset }
+    result
+  }
 
 }
 
diff --git a/core/src/main/scala/kafka/server/epoch/LeaderEpochFileCache.scala b/core/src/main/scala/kafka/server/epoch/LeaderEpochFileCache.scala
deleted file mode 100644
index 2053d8c50dd..00000000000
--- a/core/src/main/scala/kafka/server/epoch/LeaderEpochFileCache.scala
+++ /dev/null
@@ -1,343 +0,0 @@
-/**
-  * Licensed to the Apache Software Foundation (ASF) under one or more
-  * contributor license agreements.  See the NOTICE file distributed with
-  * this work for additional information regarding copyright ownership.
-  * The ASF licenses this file to You under the Apache License, Version 2.0
-  * (the "License"); you may not use this file except in compliance with
-  * the License.  You may obtain a copy of the License at
-  *
-  * http://www.apache.org/licenses/LICENSE-2.0
-  *
-  * Unless required by applicable law or agreed to in writing, software
-  * distributed under the License is distributed on an "AS IS" BASIS,
-  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-  * See the License for the specific language governing permissions and
-  * limitations under the License.
-  */
-package kafka.server.epoch
-
-import java.util
-import java.util.concurrent.locks.ReentrantReadWriteLock
-
-import kafka.server.checkpoints.LeaderEpochCheckpoint
-import kafka.utils.CoreUtils._
-import kafka.utils.Logging
-import org.apache.kafka.common.TopicPartition
-import org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.{UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET}
-
-import scala.collection.{Seq, mutable}
-import scala.jdk.CollectionConverters._
-
-/**
- * Represents a cache of (LeaderEpoch => Offset) mappings for a particular replica.
- *
- * Leader Epoch = epoch assigned to each leader by the controller.
- * Offset = offset of the first message in each epoch.
- *
- * @param topicPartition the associated topic partition
- * @param checkpoint the checkpoint file
- */
-class LeaderEpochFileCache(topicPartition: TopicPartition,
-                           checkpoint: LeaderEpochCheckpoint) extends Logging {
-  this.logIdent = s"[LeaderEpochCache $topicPartition] "
-
-  private val lock = new ReentrantReadWriteLock()
-  private val epochs = new util.TreeMap[Int, EpochEntry]()
-
-  inWriteLock(lock) {
-    checkpoint.read().foreach(assign)
-  }
-
-  /**
-    * Assigns the supplied Leader Epoch to the supplied Offset
-    * Once the epoch is assigned it cannot be reassigned
-    */
-  def assign(epoch: Int, startOffset: Long): Unit = {
-    val entry = EpochEntry(epoch, startOffset)
-    if (assign(entry)) {
-      debug(s"Appended new epoch entry $entry. Cache now contains 
${epochs.size} entries.")
-      flush()
-    }
-  }
-
-  def assign(entries: Seq[EpochEntry]): Unit = {
-    entries.foreach(entry =>
-      if (assign(entry)) {
-        debug(s"Appended new epoch entry $entry. Cache now contains 
${epochs.size} entries.")
-      })
-    flush()
-  }
-
-  private def assign(entry: EpochEntry): Boolean = {
-    if (entry.epoch < 0 || entry.startOffset < 0) {
-      throw new IllegalArgumentException(s"Received invalid partition leader epoch entry $entry")
-    }
-
-    def isUpdateNeeded: Boolean = {
-      latestEntry match {
-        case Some(lastEntry) =>
-          entry.epoch != lastEntry.epoch || entry.startOffset < lastEntry.startOffset
-        case None =>
-          true
-      }
-    }
-
-    // Check whether the append is needed before acquiring the write lock
-    // in order to avoid contention with readers in the common case
-    if (!isUpdateNeeded)
-      return false
-
-    inWriteLock(lock) {
-      if (isUpdateNeeded) {
-        maybeTruncateNonMonotonicEntries(entry)
-        epochs.put(entry.epoch, entry)
-        true
-      } else {
-        false
-      }
-    }
-  }
-
-  /**
-   * Remove any entries which violate monotonicity prior to appending a new entry
-   */
-  private def maybeTruncateNonMonotonicEntries(newEntry: EpochEntry): Unit = {
-    val removedEpochs = removeFromEnd { entry =>
-      entry.epoch >= newEntry.epoch || entry.startOffset >= newEntry.startOffset
-    }
-
-    if (removedEpochs.size > 1
-      || (removedEpochs.nonEmpty && removedEpochs.head.startOffset != newEntry.startOffset)) {
-
-      // Only log a warning if there were non-trivial removals. If the start offset of the new entry
-      // matches the start offset of the removed epoch, then no data has been written and the truncation
-      // is expected.
-      warn(s"New epoch entry $newEntry caused truncation of conflicting 
entries $removedEpochs. " +
-        s"Cache now contains ${epochs.size} entries.")
-    }
-  }
-
-  private def removeFromEnd(predicate: EpochEntry => Boolean): Seq[EpochEntry] = {
-    removeWhileMatching(epochs.descendingMap.entrySet().iterator(), predicate)
-  }
-
-  private def removeFromStart(predicate: EpochEntry => Boolean): Seq[EpochEntry] = {
-    removeWhileMatching(epochs.entrySet().iterator(), predicate)
-  }
-
-  private def removeWhileMatching(
-    iterator: util.Iterator[util.Map.Entry[Int, EpochEntry]],
-    predicate: EpochEntry => Boolean
-  ): Seq[EpochEntry] = {
-    val removedEpochs = mutable.ListBuffer.empty[EpochEntry]
-
-    while (iterator.hasNext) {
-      val entry = iterator.next().getValue
-      if (predicate.apply(entry)) {
-        removedEpochs += entry
-        iterator.remove()
-      } else {
-        return removedEpochs
-      }
-    }
-
-    removedEpochs
-  }
-
-  def nonEmpty: Boolean = inReadLock(lock) {
-    !epochs.isEmpty
-  }
-
-  def latestEntry: Option[EpochEntry] = {
-    inReadLock(lock) {
-      Option(epochs.lastEntry).map(_.getValue)
-    }
-  }
-
-  /**
-   * Returns the current Leader Epoch if one exists. This is the latest epoch
-   * which has messages assigned to it.
-   */
-  def latestEpoch: Option[Int] = {
-    latestEntry.map(_.epoch)
-  }
-
-  def previousEpoch: Option[Int] = {
-    inReadLock(lock) {
-      latestEntry.flatMap(entry => Option(epochs.lowerEntry(entry.epoch))).map(_.getKey)
-    }
-  }
-
-  /**
-   * Get the earliest cached entry if one exists.
-   */
-  def earliestEntry: Option[EpochEntry] = {
-    inReadLock(lock) {
-      Option(epochs.firstEntry).map(_.getValue)
-    }
-  }
-
-  def previousEpoch(epoch: Int): Option[Int] = {
-    inReadLock(lock) {
-      Option(epochs.lowerKey(epoch))
-    }
-  }
-
-  def nextEpoch(epoch: Int): Option[Int] = {
-    inReadLock(lock) {
-      Option(epochs.higherKey(epoch))
-    }
-  }
-
-  def epochEntry(epoch: Int): Option[EpochEntry] = {
-    inReadLock(lock) {
-      Option.apply(epochs.get(epoch))
-    }
-  }
-
-  /**
-    * Returns the Leader Epoch and the End Offset for a requested Leader Epoch.
-    *
-    * The Leader Epoch returned is the largest epoch less than or equal to the requested Leader
-    * Epoch. The End Offset is the end offset of this epoch, which is defined as the start offset
-    * of the first Leader Epoch larger than the Leader Epoch requested, or else the Log End
-    * Offset if the latest epoch was requested.
-    *
-    * During the upgrade phase, where there are existing messages may not have a leader epoch,
-    * if requestedEpoch is < the first epoch cached, UNDEFINED_EPOCH_OFFSET will be returned
-    * so that the follower falls back to High Water Mark.
-    *
-    * @param requestedEpoch requested leader epoch
-    * @param logEndOffset the existing Log End Offset
-    * @return found leader epoch and end offset
-    */
-  def endOffsetFor(requestedEpoch: Int, logEndOffset: Long): (Int, Long) = {
-    inReadLock(lock) {
-      val epochAndOffset =
-        if (requestedEpoch == UNDEFINED_EPOCH) {
-          // This may happen if a bootstrapping follower sends a request with undefined epoch or
-          // a follower is on the older message format where leader epochs are not recorded
-          (UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET)
-        } else if (latestEpoch.contains(requestedEpoch)) {
-          // For the leader, the latest epoch is always the current leader epoch that is still being written to.
-          // Followers should not have any reason to query for the end offset of the current epoch, but a consumer
-          // might if it is verifying its committed offset following a group rebalance. In this case, we return
-          // the current log end offset which makes the truncation check work as expected.
-          (requestedEpoch, logEndOffset)
-        } else {
-          val higherEntry = epochs.higherEntry(requestedEpoch)
-          if (higherEntry == null) {
-            // The requested epoch is larger than any known epoch. This case should never be hit because
-            // the latest cached epoch is always the largest.
-            (UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET)
-          } else {
-            val floorEntry = epochs.floorEntry(requestedEpoch)
-            if (floorEntry == null) {
-              // The requested epoch is smaller than any known epoch, so we return the start offset of the first
-              // known epoch which is larger than it. This may be inaccurate as there could have been
-              // epochs in between, but the point is that the data has already been removed from the log
-              // and we want to ensure that the follower can replicate correctly beginning from the leader's
-              // start offset.
-              (requestedEpoch, higherEntry.getValue.startOffset)
-            } else {
-              // We have at least one previous epoch and one subsequent epoch. The result is the first
-              // prior epoch and the starting offset of the first subsequent epoch.
-              (floorEntry.getValue.epoch, higherEntry.getValue.startOffset)
-            }
-          }
-        }
-      trace(s"Processed end offset request for epoch $requestedEpoch and 
returning epoch ${epochAndOffset._1} " +
-        s"with end offset ${epochAndOffset._2} from epoch cache of size 
${epochs.size}")
-      epochAndOffset
-    }
-  }
-
-  /**
-    * Removes all epoch entries from the store with start offsets greater than or equal to the passed offset.
-    */
-  def truncateFromEnd(endOffset: Long): Unit = {
-    inWriteLock(lock) {
-      if (endOffset >= 0 && latestEntry.exists(_.startOffset >= endOffset)) {
-        val removedEntries = removeFromEnd(_.startOffset >= endOffset)
-
-        flush()
-
-        debug(s"Cleared entries $removedEntries from epoch cache after " +
-          s"truncating to end offset $endOffset, leaving ${epochs.size} 
entries in the cache.")
-      }
-    }
-  }
-
-  /**
-    * Clears old epoch entries. This method searches for the oldest epoch < offset, updates the saved epoch offset to
-    * be offset, then clears any previous epoch entries.
-    *
-    * This method is exclusive: so truncateFromStart(6) will retain an entry at offset 6.
-    *
-    * @param startOffset the offset to clear up to
-    */
-  def truncateFromStart(startOffset: Long): Unit = {
-    inWriteLock(lock) {
-      val removedEntries = removeFromStart { entry =>
-        entry.startOffset <= startOffset
-      }
-
-      removedEntries.lastOption.foreach { firstBeforeStartOffset =>
-        val updatedFirstEntry = EpochEntry(firstBeforeStartOffset.epoch, startOffset)
-        epochs.put(updatedFirstEntry.epoch, updatedFirstEntry)
-
-        flush()
-
-        debug(s"Cleared entries $removedEntries and rewrote first entry 
$updatedFirstEntry after " +
-          s"truncating to start offset $startOffset, leaving ${epochs.size} in 
the cache.")
-      }
-    }
-  }
-
-  def epochForOffset(offset: Long): Option[Int] = {
-    inReadLock(lock) {
-      var previousEpoch: Option[Int] = None
-      epochs.values().asScala.foreach {
-        case EpochEntry(epoch, startOffset) =>
-          if (startOffset == offset)
-            return Some(epoch)
-          if (startOffset > offset)
-            return previousEpoch
-
-          previousEpoch = Some(epoch)
-      }
-      previousEpoch
-    }
-  }
-
-  /**
-    * Delete all entries.
-    */
-  def clearAndFlush(): Unit = {
-    inWriteLock(lock) {
-      epochs.clear()
-      flush()
-    }
-  }
-
-  def clear(): Unit = {
-    inWriteLock(lock) {
-      epochs.clear()
-    }
-  }
-
-  // Visible for testing
-  def epochEntries: Seq[EpochEntry] = epochs.values.asScala.toSeq
-
-  def flush(): Unit = {
-    checkpoint.write(epochs.values.asScala)
-  }
-
-}
-
-// Mapping of epoch to the first offset of the subsequent epoch
-case class EpochEntry(epoch: Int, startOffset: Long) {
-  override def toString: String = {
-    s"EpochEntry(epoch=$epoch, startOffset=$startOffset)"
-  }
-}
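
To make the endOffsetFor contract documented in the deleted class concrete, here is a small self-contained Scala sketch (illustrative only, with hypothetical values; it models just the TreeMap lookup and omits the locking, checkpointing and UNDEFINED_EPOCH constants of the real cache):

    import java.util.TreeMap

    val epochs = new TreeMap[Int, Long]()             // epoch -> start offset of that epoch
    epochs.put(0, 0L); epochs.put(2, 50L); epochs.put(4, 120L)
    val logEndOffset = 200L

    def endOffsetFor(requestedEpoch: Int): (Int, Long) =
      if (requestedEpoch == epochs.lastKey) (requestedEpoch, logEndOffset) // latest epoch ends at the LEO
      else {
        val higher = epochs.higherEntry(requestedEpoch) // first epoch after the request
        val floor = epochs.floorEntry(requestedEpoch)   // largest epoch <= the request
        if (higher == null) (-1, -1L)                   // stands in for UNDEFINED_EPOCH(_OFFSET)
        else if (floor == null) (requestedEpoch, higher.getValue) // below all cached epochs
        else (floor.getKey, higher.getValue)            // end offset = next epoch's start offset
      }

    // endOffsetFor(3) == (2, 120L): floor epoch 2 ends where epoch 4 begins
    // endOffsetFor(4) == (4, 200L): the latest epoch returns the log end offset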
diff --git a/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala b/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala
index 942e665ef06..6c6bc3c713a 100644
--- a/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala
+++ b/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala
@@ -24,7 +24,6 @@ import kafka.api.LeaderAndIsr
 import kafka.log._
 import kafka.server._
 import kafka.server.checkpoints.OffsetCheckpoints
-import kafka.server.epoch.LeaderEpochFileCache
 import kafka.server.metadata.MockConfigRepository
 import kafka.utils._
 import org.apache.kafka.common.TopicIdPartition
@@ -37,6 +36,7 @@ import org.apache.kafka.common.utils.Utils
 import org.apache.kafka.common.{TopicPartition, Uuid}
 import org.apache.kafka.server.common.MetadataVersion
 import org.apache.kafka.server.log.internals.{AppendOrigin, CleanerConfig, FetchIsolation, FetchParams, LogConfig, LogDirFailureChannel}
+import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache
 import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertTrue}
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
 import org.mockito.ArgumentMatchers
diff --git a/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala b/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
index 55c59f82207..078bc8f943d 100644
--- a/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
+++ b/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
@@ -22,7 +22,6 @@ import kafka.common.UnexpectedAppendOffsetException
 import kafka.log._
 import kafka.server._
 import kafka.server.checkpoints.OffsetCheckpoints
-import kafka.server.epoch.EpochEntry
 import kafka.utils._
 import kafka.zk.KafkaZkClient
 import org.apache.kafka.common.errors.{ApiException, FencedLeaderEpochException, InconsistentTopicIdException, NotLeaderOrFollowerException, OffsetNotAvailableException, OffsetOutOfRangeException, UnknownLeaderEpochException}
@@ -45,7 +44,6 @@ import org.mockito.invocation.InvocationOnMock
 import java.nio.ByteBuffer
 import java.util.Optional
 import java.util.concurrent.{CountDownLatch, Semaphore}
-import kafka.server.epoch.LeaderEpochFileCache
 import kafka.server.metadata.{KRaftMetadataCache, ZkMetadataCache}
 import org.apache.kafka.clients.ClientResponse
 import org.apache.kafka.common.network.ListenerName
@@ -54,9 +52,10 @@ import org.apache.kafka.common.replica.ClientMetadata.DefaultClientMetadata
 import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol}
 import org.apache.kafka.server.common.MetadataVersion
 import org.apache.kafka.server.common.MetadataVersion.IBP_2_6_IV0
-import org.apache.kafka.server.log.internals.{AppendOrigin, CleanerConfig, FetchIsolation, FetchParams, LogDirFailureChannel}
+import org.apache.kafka.server.log.internals.{AppendOrigin, CleanerConfig, EpochEntry, FetchIsolation, FetchParams, LogDirFailureChannel}
 import org.apache.kafka.server.metrics.KafkaYammerMetrics
 import org.apache.kafka.server.util.KafkaScheduler
+import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache
 import org.junit.jupiter.params.ParameterizedTest
 import org.junit.jupiter.params.provider.ValueSource
 
@@ -2651,7 +2650,7 @@ class PartitionTest extends AbstractPartitionTest {
     assertEquals(Some(0L), partition.leaderEpochStartOffsetOpt)
 
     val leaderLog = partition.localLogOrException
-    assertEquals(Some(EpochEntry(leaderEpoch, 0L)), leaderLog.leaderEpochCache.flatMap(_.latestEntry))
+    assertEquals(Optional.of(new EpochEntry(leaderEpoch, 0L)), leaderLog.leaderEpochCache.asJava.flatMap(_.latestEntry))
 
     // Write to the log to increment the log end offset.
     leaderLog.appendAsLeader(MemoryRecords.withRecords(0L, CompressionType.NONE, 0,
@@ -2675,7 +2674,7 @@ class PartitionTest extends AbstractPartitionTest {
     assertEquals(leaderEpoch, partition.getLeaderEpoch)
     assertEquals(Set(leaderId), partition.partitionState.isr)
     assertEquals(Some(0L), partition.leaderEpochStartOffsetOpt)
-    assertEquals(Some(EpochEntry(leaderEpoch, 0L)), leaderLog.leaderEpochCache.flatMap(_.latestEntry))
+    assertEquals(Optional.of(new EpochEntry(leaderEpoch, 0L)), leaderLog.leaderEpochCache.asJava.flatMap(_.latestEntry))
   }
 
   @Test
diff --git a/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala b/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala
index 464af9c6aa1..3c49edfde16 100644
--- a/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala
@@ -21,7 +21,6 @@ import java.io.{BufferedWriter, File, FileWriter, IOException}
 import java.nio.ByteBuffer
 import java.nio.file.{Files, NoSuchFileException, Paths}
 import java.util.Properties
-import kafka.server.epoch.{EpochEntry, LeaderEpochFileCache}
 import kafka.server.{BrokerTopicStats, KafkaConfig}
 import kafka.server.metadata.MockConfigRepository
 import kafka.utils.{MockTime, TestUtils}
@@ -32,8 +31,9 @@ import org.apache.kafka.common.record.{CompressionType, ControlRecordType, Defau
 import org.apache.kafka.common.utils.{Time, Utils}
 import org.apache.kafka.server.common.MetadataVersion
 import org.apache.kafka.server.common.MetadataVersion.IBP_0_11_0_IV0
-import org.apache.kafka.server.log.internals.{AbortedTxn, CleanerConfig, FetchDataInfo, LogConfig, LogDirFailureChannel, OffsetIndex, SnapshotFile}
+import org.apache.kafka.server.log.internals.{AbortedTxn, CleanerConfig, EpochEntry, FetchDataInfo, LogConfig, LogDirFailureChannel, OffsetIndex, SnapshotFile}
 import org.apache.kafka.server.util.Scheduler
+import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache
 import org.junit.jupiter.api.Assertions.{assertDoesNotThrow, assertEquals, assertFalse, assertNotEquals, assertThrows, assertTrue}
 import org.junit.jupiter.api.function.Executable
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
@@ -1374,11 +1374,11 @@ class LogLoaderTest {
     val fourthBatch = singletonRecordsWithLeaderEpoch(value = "random".getBytes, leaderEpoch = 3, offset = 3)
     log.appendAsFollower(records = fourthBatch)
 
-    assertEquals(ListBuffer(EpochEntry(1, 0), EpochEntry(2, 1), EpochEntry(3, 3)), leaderEpochCache.epochEntries)
+    assertEquals(java.util.Arrays.asList(new EpochEntry(1, 0), new EpochEntry(2, 1), new EpochEntry(3, 3)), leaderEpochCache.epochEntries)
 
     // deliberately remove some of the epoch entries
     leaderEpochCache.truncateFromEnd(2)
-    assertNotEquals(ListBuffer(EpochEntry(1, 0), EpochEntry(2, 1), EpochEntry(3, 3)), leaderEpochCache.epochEntries)
+    assertNotEquals(java.util.Arrays.asList(new EpochEntry(1, 0), new EpochEntry(2, 1), new EpochEntry(3, 3)), leaderEpochCache.epochEntries)
     log.close()
 
     // reopen the log and recover from the beginning
@@ -1386,7 +1386,7 @@ class LogLoaderTest {
     val recoveredLeaderEpochCache = recoveredLog.leaderEpochCache.get
 
     // epoch entries should be recovered
-    assertEquals(ListBuffer(EpochEntry(1, 0), EpochEntry(2, 1), EpochEntry(3, 3)), recoveredLeaderEpochCache.epochEntries)
+    assertEquals(java.util.Arrays.asList(new EpochEntry(1, 0), new EpochEntry(2, 1), new EpochEntry(3, 3)), recoveredLeaderEpochCache.epochEntries)
     recoveredLog.close()
   }
 
diff --git a/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala b/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala
index 9bd5c56e7ac..bad5e280b8e 100644
--- a/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala
@@ -18,21 +18,20 @@ package kafka.log
 
 import java.io.File
 import java.util.OptionalLong
-
-import kafka.server.checkpoints.LeaderEpochCheckpoint
-import kafka.server.epoch.{EpochEntry, LeaderEpochFileCache}
 import kafka.utils.TestUtils
 import kafka.utils.TestUtils.checkEquals
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.config.TopicConfig
 import org.apache.kafka.common.record._
 import org.apache.kafka.common.utils.{MockTime, Time, Utils}
-import org.apache.kafka.server.log.internals.{BatchMetadata, LogConfig, ProducerStateEntry}
+import org.apache.kafka.server.log.internals.{BatchMetadata, EpochEntry, LogConfig, ProducerStateEntry}
+import org.apache.kafka.storage.internals.checkpoint.LeaderEpochCheckpoint
+import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
 
+import java.util
 import scala.collection._
-import scala.collection.mutable.ArrayBuffer
 import scala.jdk.CollectionConverters._
 
 class LogSegmentTest {
@@ -381,11 +380,11 @@ class LogSegmentTest {
     val checkpoint: LeaderEpochCheckpoint = new LeaderEpochCheckpoint {
       private var epochs = Seq.empty[EpochEntry]
 
-      override def write(epochs: Iterable[EpochEntry]): Unit = {
-        this.epochs = epochs.toVector
+      override def write(epochs: util.Collection[EpochEntry]): Unit = {
+        this.epochs = epochs.asScala.toSeq
       }
 
-      override def read(): Seq[EpochEntry] = this.epochs
+      override def read(): java.util.List[EpochEntry] = this.epochs.asJava
     }
 
     val cache = new LeaderEpochFileCache(topicPartition, checkpoint)
@@ -406,9 +405,9 @@ class LogSegmentTest {
         new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes)))
 
     seg.recover(newProducerStateManager(), Some(cache))
-    assertEquals(ArrayBuffer(EpochEntry(epoch = 0, startOffset = 104L),
-                             EpochEntry(epoch = 1, startOffset = 106),
-                             EpochEntry(epoch = 2, startOffset = 110)),
+    assertEquals(java.util.Arrays.asList(new EpochEntry(0, 104L),
+                             new EpochEntry(1, 106),
+                             new EpochEntry(2, 110)),
       cache.epochEntries)
   }
 
diff --git a/core/src/test/scala/unit/kafka/log/LogTestUtils.scala b/core/src/test/scala/unit/kafka/log/LogTestUtils.scala
index a3cac4c27d3..bd9c8d7ca53 100644
--- a/core/src/test/scala/unit/kafka/log/LogTestUtils.scala
+++ b/core/src/test/scala/unit/kafka/log/LogTestUtils.scala
@@ -21,13 +21,12 @@ import kafka.log.remote.RemoteLogManager
 
 import java.io.File
 import java.util.Properties
-import kafka.server.checkpoints.LeaderEpochCheckpointFile
 import kafka.server.BrokerTopicStats
 import kafka.utils.TestUtils
 import org.apache.kafka.common.Uuid
 import org.apache.kafka.common.record.{CompressionType, ControlRecordType, EndTransactionMarker, FileRecords, MemoryRecords, RecordBatch, SimpleRecord}
 import org.apache.kafka.common.utils.{Time, Utils}
-import org.apache.kafka.server.log.internals.{AbortedTxn, AppendOrigin, FetchDataInfo, FetchIsolation, LazyIndex, LogConfig, LogDirFailureChannel, TransactionIndex}
+import org.apache.kafka.server.log.internals.{AbortedTxn, AppendOrigin,  FetchDataInfo, FetchIsolation, LazyIndex, LogConfig, LogDirFailureChannel, TransactionIndex}
 import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse}
 
 import java.nio.file.Files
@@ -35,8 +34,8 @@ import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
 import kafka.log
 import org.apache.kafka.common.config.TopicConfig
 import org.apache.kafka.server.util.Scheduler
+import org.apache.kafka.storage.internals.checkpoint.LeaderEpochCheckpointFile
 
-import scala.collection.Iterable
 import scala.jdk.CollectionConverters._
 
 object LogTestUtils {
diff --git a/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala b/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala
index f86102d0b31..a3f8404e2f4 100755
--- a/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala
+++ b/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala
@@ -17,15 +17,8 @@
 
 package kafka.log
 
-import java.io._
-import java.nio.ByteBuffer
-import java.nio.file.Files
-import java.util.concurrent.{Callable, ConcurrentHashMap, Executors}
-import java.util.{Optional, Properties}
 import kafka.common.{OffsetsOutOfOrderException, UnexpectedAppendOffsetException}
 import kafka.log.remote.RemoteLogManager
-import kafka.server.checkpoints.LeaderEpochCheckpointFile
-import kafka.server.epoch.{EpochEntry, LeaderEpochFileCache}
 import kafka.server.{BrokerTopicStats, KafkaConfig, PartitionMetadataFile}
 import kafka.utils._
 import org.apache.kafka.common.config.TopicConfig
@@ -38,20 +31,28 @@ import org.apache.kafka.common.record.MemoryRecords.RecordFilter
 import org.apache.kafka.common.record._
 import org.apache.kafka.common.requests.{ListOffsetsRequest, ListOffsetsResponse}
 import org.apache.kafka.common.utils.{BufferSupplier, Time, Utils}
-import org.apache.kafka.server.log.internals.{AbortedTxn, AppendOrigin, FetchIsolation, LogConfig, LogOffsetMetadata, RecordValidationException}
+import org.apache.kafka.server.log.internals.{AbortedTxn, AppendOrigin, EpochEntry, FetchIsolation, LogConfig, LogOffsetMetadata, RecordValidationException}
 import org.apache.kafka.server.log.remote.metadata.storage.TopicBasedRemoteLogMetadataManagerConfig
 import org.apache.kafka.server.metrics.KafkaYammerMetrics
 import org.apache.kafka.server.util.{KafkaScheduler, Scheduler}
+import org.apache.kafka.storage.internals.checkpoint.LeaderEpochCheckpointFile
+import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
 import org.mockito.ArgumentMatchers
 import org.mockito.ArgumentMatchers.anyLong
 import org.mockito.Mockito.{mock, when}
 
+import java.io._
+import java.nio.ByteBuffer
+import java.nio.file.Files
+import java.util.concurrent.{Callable, ConcurrentHashMap, Executors}
+import java.util.{Optional, Properties}
 import scala.annotation.nowarn
 import scala.collection.Map
-import scala.jdk.CollectionConverters._
 import scala.collection.mutable.ListBuffer
+import scala.compat.java8.OptionConverters._
+import scala.jdk.CollectionConverters._
 
 class UnifiedLogTest {
   var config: KafkaConfig = _
@@ -596,23 +597,23 @@ class UnifiedLogTest {
     val records = TestUtils.records(List(new SimpleRecord("a".getBytes), new 
SimpleRecord("b".getBytes)),
       baseOffset = 27)
     appendAsFollower(log, records, leaderEpoch = 19)
-    assertEquals(Some(EpochEntry(epoch = 19, startOffset = 27)),
-      log.leaderEpochCache.flatMap(_.latestEntry))
+    assertEquals(Some(new EpochEntry(19, 27)),
+      log.leaderEpochCache.flatMap(_.latestEntry.asScala))
     assertEquals(29, log.logEndOffset)
 
     def verifyTruncationClearsEpochCache(epoch: Int, truncationOffset: Long): Unit = {
       // Simulate becoming a leader
       log.maybeAssignEpochStartOffset(leaderEpoch = epoch, startOffset = log.logEndOffset)
-      assertEquals(Some(EpochEntry(epoch = epoch, startOffset = 29)),
-        log.leaderEpochCache.flatMap(_.latestEntry))
+      assertEquals(Some(new EpochEntry(epoch, 29)),
+        log.leaderEpochCache.flatMap(_.latestEntry.asScala))
       assertEquals(29, log.logEndOffset)
 
       // Now we become the follower and truncate to an offset greater
       // than or equal to the log end offset. The trivial epoch entry
       // at the end of the log should be gone
       log.truncateTo(truncationOffset)
-      assertEquals(Some(EpochEntry(epoch = 19, startOffset = 27)),
-        log.leaderEpochCache.flatMap(_.latestEntry))
+      assertEquals(Some(new EpochEntry(19, 27)),
+        log.leaderEpochCache.flatMap(_.latestEntry.asScala))
       assertEquals(29, log.logEndOffset)
     }
 
@@ -2376,12 +2377,12 @@ class UnifiedLogTest {
     val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1000, 
indexIntervalBytes = 1, maxMessageBytes = 64 * 1024)
     val log = createLog(logDir, logConfig)
     log.appendAsLeader(TestUtils.records(List(new 
SimpleRecord("foo".getBytes()))), leaderEpoch = 5)
-    assertEquals(Some(5), log.leaderEpochCache.flatMap(_.latestEpoch))
+    assertEquals(Some(5), log.leaderEpochCache.flatMap(_.latestEpoch.asScala))
 
     log.appendAsFollower(TestUtils.records(List(new 
SimpleRecord("foo".getBytes())),
       baseOffset = 1L,
       magicValue = RecordVersion.V1.value))
-    assertEquals(None, log.leaderEpochCache.flatMap(_.latestEpoch))
+    assertEquals(None, log.leaderEpochCache.flatMap(_.latestEpoch.asScala))
   }
 
   @nowarn("cat=deprecation")
@@ -2540,7 +2541,7 @@ class UnifiedLogTest {
     log.deleteOldSegments()
     assertEquals(1, log.numberOfSegments, "The deleted segments should be 
gone.")
     assertEquals(1, epochCache(log).epochEntries.size, "Epoch entries should 
have gone.")
-    assertEquals(EpochEntry(1, 100), epochCache(log).epochEntries.head, "Epoch 
entry should be the latest epoch and the leo.")
+    assertEquals(new EpochEntry(1, 100), epochCache(log).epochEntries.get(0), 
"Epoch entry should be the latest epoch and the leo.")
 
     // append some messages to create some segments
     for (_ <- 0 until 100)
@@ -2791,7 +2792,7 @@ class UnifiedLogTest {
     log.deleteOldSegments()
 
     //The oldest epoch entry should have been removed
-    assertEquals(ListBuffer(EpochEntry(1, 5), EpochEntry(2, 10)), 
cache.epochEntries)
+    assertEquals(java.util.Arrays.asList(new EpochEntry(1, 5), new 
EpochEntry(2, 10)), cache.epochEntries)
   }
 
   @Test
@@ -2816,7 +2817,7 @@ class UnifiedLogTest {
     log.deleteOldSegments()
 
     //The first entry should have gone from (0,0) => (0,5)
-    assertEquals(ListBuffer(EpochEntry(0, 5), EpochEntry(1, 7), EpochEntry(2, 
10)), cache.epochEntries)
+    assertEquals(java.util.Arrays.asList(new EpochEntry(0, 5), new 
EpochEntry(1, 7), new EpochEntry(2, 10)), cache.epochEntries)
   }
 
   @Test
diff --git 
a/core/src/test/scala/unit/kafka/log/remote/RemoteLogManagerTest.scala 
b/core/src/test/scala/unit/kafka/log/remote/RemoteLogManagerTest.scala
index 8876dbd27f7..9c565fe364f 100644
--- a/core/src/test/scala/unit/kafka/log/remote/RemoteLogManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/log/remote/RemoteLogManagerTest.scala
@@ -19,16 +19,16 @@ package kafka.log.remote
 import kafka.cluster.Partition
 import kafka.log.UnifiedLog
 import kafka.server.KafkaConfig
-import kafka.server.checkpoints.LeaderEpochCheckpoint
-import kafka.server.epoch.{EpochEntry, LeaderEpochFileCache}
 import kafka.utils.MockTime
 import org.apache.kafka.common.config.AbstractConfig
 import org.apache.kafka.common.record.FileRecords.TimestampAndOffset
 import org.apache.kafka.common.record.{CompressionType, MemoryRecords, 
SimpleRecord}
 import org.apache.kafka.common.{KafkaException, TopicIdPartition, 
TopicPartition, Uuid}
-import org.apache.kafka.server.log.internals.{OffsetIndex, TimeIndex}
+import org.apache.kafka.server.log.internals._
 import 
org.apache.kafka.server.log.remote.storage.RemoteStorageManager.IndexType
 import org.apache.kafka.server.log.remote.storage._
+import org.apache.kafka.storage.internals.checkpoint.LeaderEpochCheckpoint
+import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache
 import org.apache.kafka.test.TestUtils
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{BeforeEach, Test}
@@ -63,8 +63,8 @@ class RemoteLogManagerTest {
 
   val checkpoint: LeaderEpochCheckpoint = new LeaderEpochCheckpoint {
     var epochs: Seq[EpochEntry] = Seq()
-    override def write(epochs: Iterable[EpochEntry]): Unit = this.epochs = 
epochs.toSeq
-    override def read(): Seq[EpochEntry] = this.epochs
+    override def write(epochs: util.Collection[EpochEntry]): Unit = 
this.epochs = epochs.asScala.toSeq
+    override def read(): util.List[EpochEntry] = this.epochs.asJava
   }
 
   @BeforeEach
@@ -227,9 +227,9 @@ class RemoteLogManagerTest {
       .thenAnswer(_ => new ByteArrayInputStream(records(ts, startOffset, 
targetLeaderEpoch).buffer().array()))
 
     val leaderEpochFileCache = new LeaderEpochFileCache(tp, checkpoint)
-    leaderEpochFileCache.assign(epoch = 5, startOffset = 99L)
-    leaderEpochFileCache.assign(epoch = targetLeaderEpoch, startOffset = 
startOffset)
-    leaderEpochFileCache.assign(epoch = 12, startOffset = 500L)
+    leaderEpochFileCache.assign(5, 99L)
+    leaderEpochFileCache.assign(targetLeaderEpoch, startOffset)
+    leaderEpochFileCache.assign(12, 500L)
 
     
remoteLogManager.onLeadershipChange(Set(mockPartition(leaderTopicIdPartition)), 
Set(), topicIds)
     // Fetching message for timestamp `ts` will return the message with 
startOffset+1, and `ts+1` as there are no
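
    The stub above illustrates the migrated checkpoint interface: write now
    takes a java.util.Collection and read returns a java.util.List. For
    reference, a hedged Java analogue of the same in-memory fake (the class
    name is illustrative, not part of the commit):

    import org.apache.kafka.server.log.internals.EpochEntry;
    import org.apache.kafka.storage.internals.checkpoint.LeaderEpochCheckpoint;

    import java.util.ArrayList;
    import java.util.Collection;
    import java.util.Collections;
    import java.util.List;

    // A Java analogue of the in-memory Scala stub above: state lives on the
    // heap, so tests can exercise LeaderEpochFileCache without touching disk.
    class InMemoryLeaderEpochCheckpoint implements LeaderEpochCheckpoint {
        private List<EpochEntry> epochs = Collections.emptyList();

        @Override
        public void write(Collection<EpochEntry> epochs) {
            this.epochs = new ArrayList<>(epochs);
        }

        @Override
        public List<EpochEntry> read() {
            return epochs;
        }
    }
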
diff --git 
a/core/src/test/scala/unit/kafka/server/checkpoints/LeaderEpochCheckpointFileWithFailureHandlerTest.scala
 
b/core/src/test/scala/unit/kafka/server/checkpoints/LeaderEpochCheckpointFileWithFailureHandlerTest.scala
index cbd04c00452..a338de16f04 100644
--- 
a/core/src/test/scala/unit/kafka/server/checkpoints/LeaderEpochCheckpointFileWithFailureHandlerTest.scala
+++ 
b/core/src/test/scala/unit/kafka/server/checkpoints/LeaderEpochCheckpointFileWithFailureHandlerTest.scala
@@ -16,8 +16,9 @@
   */
 package kafka.server.checkpoints
 
-import kafka.server.epoch.EpochEntry
 import kafka.utils.{Logging, TestUtils}
+import org.apache.kafka.server.log.internals.{EpochEntry, LogDirFailureChannel}
+import org.apache.kafka.storage.internals.checkpoint.LeaderEpochCheckpointFile
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.Test
 
@@ -27,10 +28,10 @@ class LeaderEpochCheckpointFileWithFailureHandlerTest 
extends Logging {
   def shouldPersistAndOverwriteAndReloadFile(): Unit ={
     val file = TestUtils.tempFile("temp-checkpoint-file", 
System.nanoTime().toString)
 
-    val checkpoint = new LeaderEpochCheckpointFile(file)
+    val checkpoint = new LeaderEpochCheckpointFile(file, new 
LogDirFailureChannel(1))
 
     //Given
-    val epochs = Seq(EpochEntry(0, 1L), EpochEntry(1, 2L), EpochEntry(2, 3L))
+    val epochs = java.util.Arrays.asList(new EpochEntry(0, 1L), new 
EpochEntry(1, 2L), new EpochEntry(2, 3L))
 
     //When
     checkpoint.write(epochs)
@@ -39,7 +40,7 @@ class LeaderEpochCheckpointFileWithFailureHandlerTest extends 
Logging {
     assertEquals(epochs, checkpoint.read())
 
     //Given overwrite
-    val epochs2 = Seq(EpochEntry(3, 4L), EpochEntry(4, 5L))
+    val epochs2 = java.util.Arrays.asList(new EpochEntry(3, 4L), new 
EpochEntry(4, 5L))
 
     //When
     checkpoint.write(epochs2)
@@ -53,12 +54,12 @@ class LeaderEpochCheckpointFileWithFailureHandlerTest 
extends Logging {
     val file = TestUtils.tempFile("temp-checkpoint-file", 
System.nanoTime().toString)
 
     //Given a file with data in
-    val checkpoint = new LeaderEpochCheckpointFile(file)
-    val epochs = Seq(EpochEntry(0, 1L), EpochEntry(1, 2L), EpochEntry(2, 3L))
+    val checkpoint = new LeaderEpochCheckpointFile(file, new 
LogDirFailureChannel(1))
+    val epochs = java.util.Arrays.asList(new EpochEntry(0, 1L), new 
EpochEntry(1, 2L), new EpochEntry(2, 3L))
     checkpoint.write(epochs)
 
     //When we recreate
-    val checkpoint2 = new LeaderEpochCheckpointFile(file)
+    val checkpoint2 = new LeaderEpochCheckpointFile(file, new 
LogDirFailureChannel(1))
 
     //The data should still be there
     assertEquals(epochs, checkpoint2.read())
diff --git 
a/core/src/test/scala/unit/kafka/server/checkpoints/OffsetCheckpointFileWithFailureHandlerTest.scala
 
b/core/src/test/scala/unit/kafka/server/checkpoints/OffsetCheckpointFileWithFailureHandlerTest.scala
index 9439f388d43..c1ddb89831c 100644
--- 
a/core/src/test/scala/unit/kafka/server/checkpoints/OffsetCheckpointFileWithFailureHandlerTest.scala
+++ 
b/core/src/test/scala/unit/kafka/server/checkpoints/OffsetCheckpointFileWithFailureHandlerTest.scala
@@ -20,10 +20,12 @@ import kafka.utils.{Logging, TestUtils}
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.errors.KafkaStorageException
 import org.apache.kafka.server.log.internals.LogDirFailureChannel
+import 
org.apache.kafka.storage.internals.checkpoint.CheckpointFileWithFailureHandler
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.Test
 import org.mockito.Mockito
 
+import java.util.Collections
 import scala.collection.Map
 
 class OffsetCheckpointFileWithFailureHandlerTest extends Logging {
@@ -95,7 +97,7 @@ class OffsetCheckpointFileWithFailureHandlerTest extends 
Logging {
     val logDirFailureChannel = new LogDirFailureChannel(10)
     val checkpointFile = new CheckpointFileWithFailureHandler(file, 
OffsetCheckpointFile.CurrentVersion + 1,
       OffsetCheckpointFile.Formatter, logDirFailureChannel, file.getParent)
-    checkpointFile.write(Seq(new TopicPartition("foo", 5) -> 10L))
+    checkpointFile.write(Collections.singletonList(new TopicPartition("foo", 
5) -> 10L))
     assertThrows(classOf[KafkaStorageException], () => new 
OffsetCheckpointFile(checkpointFile.file, logDirFailureChannel).read())
   }
 
diff --git 
a/core/src/test/scala/unit/kafka/server/epoch/EpochDrivenReplicationProtocolAcceptanceTest.scala
 
b/core/src/test/scala/unit/kafka/server/epoch/EpochDrivenReplicationProtocolAcceptanceTest.scala
index 9e347656846..3ecd5814fa2 100644
--- 
a/core/src/test/scala/unit/kafka/server/epoch/EpochDrivenReplicationProtocolAcceptanceTest.scala
+++ 
b/core/src/test/scala/unit/kafka/server/epoch/EpochDrivenReplicationProtocolAcceptanceTest.scala
@@ -17,27 +17,27 @@
 
 package kafka.server.epoch
 
-import java.io.{File, RandomAccessFile}
-import java.util.Properties
-import kafka.log.{UnifiedLog, LogLoader}
+import kafka.log.{LogLoader, UnifiedLog}
 import kafka.server.KafkaConfig._
-import kafka.server.{KafkaConfig, KafkaServer}
+import kafka.server.{KafkaConfig, KafkaServer, QuorumTestHarness}
 import kafka.tools.DumpLogSegments
-import kafka.utils.{CoreUtils, Logging, TestUtils}
 import kafka.utils.TestUtils._
-import kafka.server.QuorumTestHarness
+import kafka.utils.{CoreUtils, Logging, TestUtils}
 import org.apache.kafka.clients.consumer.{ConsumerConfig, KafkaConsumer}
 import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord}
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.record.RecordBatch
 import org.apache.kafka.common.serialization.ByteArrayDeserializer
 import org.apache.kafka.server.common.MetadataVersion
+import org.apache.kafka.server.log.internals.EpochEntry
+import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache
 import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue}
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo}
 
-import scala.jdk.CollectionConverters._
-import scala.collection.mutable.{ListBuffer => Buffer}
+import java.io.{File, RandomAccessFile}
+import java.util.{Collections, Properties}
 import scala.collection.Seq
+import scala.jdk.CollectionConverters._
 
 /**
  * These tests were written to assert that the addition of leader epochs to the replication protocol fixes the problems
@@ -89,23 +89,23 @@ class EpochDrivenReplicationProtocolAcceptanceTest extends 
QuorumTestHarness wit
     assertEquals(0, latestRecord(follower).partitionLeaderEpoch)
 
     //Both leader and follower should have recorded Epoch 0 at Offset 0
-    assertEquals(Buffer(EpochEntry(0, 0)), epochCache(leader).epochEntries)
-    assertEquals(Buffer(EpochEntry(0, 0)), epochCache(follower).epochEntries)
+    assertEquals(Collections.singletonList(new EpochEntry(0, 0)), 
epochCache(leader).epochEntries)
+    assertEquals(Collections.singletonList(new EpochEntry(0, 0)), 
epochCache(follower).epochEntries)
 
     //Bounce the follower
     bounce(follower)
     awaitISR(tp)
 
     //Nothing happens yet as we haven't sent any new messages.
-    assertEquals(Buffer(EpochEntry(0, 0), EpochEntry(1, 1)), 
epochCache(leader).epochEntries)
-    assertEquals(Buffer(EpochEntry(0, 0)), epochCache(follower).epochEntries)
+    assertEquals(java.util.Arrays.asList(new EpochEntry(0, 0), new 
EpochEntry(1, 1)), epochCache(leader).epochEntries)
+    assertEquals(Collections.singletonList(new EpochEntry(0, 0)), 
epochCache(follower).epochEntries)
 
     //Send a message
     producer.send(new ProducerRecord(topic, 0, null, msg)).get
 
     //Epoch1 should now propagate to the follower with the written message
-    assertEquals(Buffer(EpochEntry(0, 0), EpochEntry(1, 1)), 
epochCache(leader).epochEntries)
-    assertEquals(Buffer(EpochEntry(0, 0), EpochEntry(1, 1)), 
epochCache(follower).epochEntries)
+    assertEquals(java.util.Arrays.asList(new EpochEntry(0, 0), new 
EpochEntry(1, 1)), epochCache(leader).epochEntries)
+    assertEquals(java.util.Arrays.asList(new EpochEntry(0, 0), new 
EpochEntry(1, 1)), epochCache(follower).epochEntries)
 
     //The new message should have epoch 1 stamped
     assertEquals(1, latestRecord(leader).partitionLeaderEpoch())
@@ -116,8 +116,8 @@ class EpochDrivenReplicationProtocolAcceptanceTest extends 
QuorumTestHarness wit
     awaitISR(tp)
 
     //Epochs 2 should be added to the leader, but not on the follower (yet), 
as there has been no replication.
-    assertEquals(Buffer(EpochEntry(0, 0), EpochEntry(1, 1), EpochEntry(2, 2)), 
epochCache(leader).epochEntries)
-    assertEquals(Buffer(EpochEntry(0, 0), EpochEntry(1, 1)), 
epochCache(follower).epochEntries)
+    assertEquals(java.util.Arrays.asList(new EpochEntry(0, 0), new 
EpochEntry(1, 1), new EpochEntry(2, 2)), epochCache(leader).epochEntries)
+    assertEquals(java.util.Arrays.asList(new EpochEntry(0, 0), new 
EpochEntry(1, 1)), epochCache(follower).epochEntries)
 
     //Send a message
     producer.send(new ProducerRecord(topic, 0, null, msg)).get
@@ -127,8 +127,8 @@ class EpochDrivenReplicationProtocolAcceptanceTest extends 
QuorumTestHarness wit
     assertEquals(2, latestRecord(follower).partitionLeaderEpoch())
 
     //The leader epoch files should now match on leader and follower
-    assertEquals(Buffer(EpochEntry(0, 0), EpochEntry(1, 1), EpochEntry(2, 2)), 
epochCache(leader).epochEntries)
-    assertEquals(Buffer(EpochEntry(0, 0), EpochEntry(1, 1), EpochEntry(2, 2)), 
epochCache(follower).epochEntries)
+    assertEquals(java.util.Arrays.asList(new EpochEntry(0, 0), new 
EpochEntry(1, 1), new EpochEntry(2, 2)), epochCache(leader).epochEntries)
+    assertEquals(java.util.Arrays.asList(new EpochEntry(0, 0), new 
EpochEntry(1, 1), new EpochEntry(2, 2)), epochCache(follower).epochEntries)
   }
 
   @Test
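
    The assertions above now compare lists built with
    Collections.singletonList and java.util.Arrays.asList against the
    java.util.List returned by epochCache(...).epochEntries. This is safe
    because List.equals is specified to compare contents in order, regardless
    of the implementation class. A minimal, self-contained illustration (not
    part of the commit):

    import java.util.Arrays;
    import java.util.Collections;
    import java.util.List;

    public class ListEqualityDemo {
        public static void main(String[] args) {
            List<String> singleton = Collections.singletonList("entry");
            List<String> asList = Arrays.asList("entry");
            // java.util.List#equals is content-based: same size, equal
            // elements in the same order, independent of the concrete class.
            System.out.println(singleton.equals(asList)); // true
        }
    }
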
diff --git 
a/core/src/test/scala/unit/kafka/server/epoch/LeaderEpochFileCacheTest.scala 
b/core/src/test/scala/unit/kafka/server/epoch/LeaderEpochFileCacheTest.scala
index 63568fdb3e7..b66a0b18d11 100644
--- a/core/src/test/scala/unit/kafka/server/epoch/LeaderEpochFileCacheTest.scala
+++ b/core/src/test/scala/unit/kafka/server/epoch/LeaderEpochFileCacheTest.scala
@@ -17,18 +17,20 @@
 
 package kafka.server.epoch
 
-import java.io.File
-
-import scala.collection.Seq
-import scala.collection.mutable.ListBuffer
-
-import kafka.server.checkpoints.{LeaderEpochCheckpoint, 
LeaderEpochCheckpointFile}
 import kafka.utils.TestUtils
-import 
org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.{UNDEFINED_EPOCH,
 UNDEFINED_EPOCH_OFFSET}
 import org.apache.kafka.common.TopicPartition
+import 
org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.{UNDEFINED_EPOCH,
 UNDEFINED_EPOCH_OFFSET}
+import org.apache.kafka.server.log.internals.{EpochEntry, LogDirFailureChannel}
+import org.apache.kafka.storage.internals.checkpoint.{LeaderEpochCheckpoint, 
LeaderEpochCheckpointFile}
+import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.Test
 
+import java.io.File
+import java.util.{Collections, OptionalInt}
+import scala.collection.Seq
+import scala.jdk.CollectionConverters._
+
 /**
   * Unit test for the LeaderEpochFileCache.
   */
@@ -36,49 +38,50 @@ class LeaderEpochFileCacheTest {
   val tp = new TopicPartition("TestTopic", 5)
   private val checkpoint: LeaderEpochCheckpoint = new LeaderEpochCheckpoint {
     private var epochs: Seq[EpochEntry] = Seq()
-    override def write(epochs: Iterable[EpochEntry]): Unit = this.epochs = 
epochs.toSeq
-    override def read(): Seq[EpochEntry] = this.epochs
+    override def write(epochs: java.util.Collection[EpochEntry]): Unit = 
this.epochs = epochs.asScala.toSeq
+    override def read(): java.util.List[EpochEntry] = this.epochs.asJava
   }
+
   private val cache = new LeaderEpochFileCache(tp, checkpoint)
 
   @Test
   def testPreviousEpoch(): Unit = {
-    assertEquals(None, cache.previousEpoch)
+    assertEquals(OptionalInt.empty(), cache.previousEpoch)
 
-    cache.assign(epoch = 2, startOffset = 10)
-    assertEquals(None, cache.previousEpoch)
+    cache.assign(2, 10)
+    assertEquals(OptionalInt.empty(), cache.previousEpoch)
 
-    cache.assign(epoch = 4, startOffset = 15)
-    assertEquals(Some(2), cache.previousEpoch)
+    cache.assign(4, 15)
+    assertEquals(OptionalInt.of(2), cache.previousEpoch)
 
-    cache.assign(epoch = 10, startOffset = 20)
-    assertEquals(Some(4), cache.previousEpoch)
+    cache.assign(10, 20)
+    assertEquals(OptionalInt.of(4), cache.previousEpoch)
 
     cache.truncateFromEnd(18)
-    assertEquals(Some(2), cache.previousEpoch)
+    assertEquals(OptionalInt.of(2), cache.previousEpoch)
   }
 
   @Test
   def shouldAddEpochAndMessageOffsetToCache() = {
     //When
-    cache.assign(epoch = 2, startOffset = 10)
+    cache.assign(2, 10)
     val logEndOffset = 11
 
     //Then
-    assertEquals(Some(2), cache.latestEpoch)
-    assertEquals(EpochEntry(2, 10), cache.epochEntries(0))
-    assertEquals((2, logEndOffset), cache.endOffsetFor(2, logEndOffset)) 
//should match logEndOffset
+    assertEquals(OptionalInt.of(2), cache.latestEpoch)
+    assertEquals(new EpochEntry(2, 10), cache.epochEntries().get(0))
+    assertEquals((2, logEndOffset), toTuple(cache.endOffsetFor(2, 
logEndOffset))) //should match logEndOffset
   }
 
   @Test
   def shouldReturnLogEndOffsetIfLatestEpochRequested() = {
     //When just one epoch
-    cache.assign(epoch = 2, startOffset = 11)
-    cache.assign(epoch = 2, startOffset = 12)
+    cache.assign(2, 11)
+    cache.assign(2, 12)
     val logEndOffset = 14
 
     //Then
-    assertEquals((2, logEndOffset), cache.endOffsetFor(2, logEndOffset))
+    assertEquals((2, logEndOffset), toTuple(cache.endOffsetFor(2, 
logEndOffset)))
   }
 
   @Test
@@ -86,11 +89,11 @@ class LeaderEpochFileCacheTest {
     val expectedEpochEndOffset = (UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET)
 
     // assign couple of epochs
-    cache.assign(epoch = 2, startOffset = 11)
-    cache.assign(epoch = 3, startOffset = 12)
+    cache.assign(2, 11)
+    cache.assign(3, 12)
 
     //When (say a bootstrapping follower) sends a request for UNDEFINED_EPOCH
-    val epochAndOffsetFor = cache.endOffsetFor(UNDEFINED_EPOCH, 0L)
+    val epochAndOffsetFor = toTuple(cache.endOffsetFor(UNDEFINED_EPOCH, 0L))
 
     //Then
     assertEquals(expectedEpochEndOffset,
@@ -108,8 +111,8 @@ class LeaderEpochFileCacheTest {
     cache.assign(2, 10)
 
     //Then the offset should NOT have been updated
-    assertEquals(logEndOffset, cache.epochEntries(0).startOffset)
-    assertEquals(ListBuffer(EpochEntry(2, 9)), cache.epochEntries)
+    assertEquals(logEndOffset, cache.epochEntries.get(0).startOffset)
+    assertEquals(java.util.Arrays.asList(new EpochEntry(2, 9)), 
cache.epochEntries())
   }
 
   @Test
@@ -121,7 +124,7 @@ class LeaderEpochFileCacheTest {
     cache.assign(3, 9)
 
     //Then epoch should have been updated
-    assertEquals(ListBuffer(EpochEntry(3, 9)), cache.epochEntries)
+    assertEquals(java.util.Arrays.asList(new EpochEntry(3, 9)), 
cache.epochEntries)
   }
 
   @Test
@@ -132,19 +135,19 @@ class LeaderEpochFileCacheTest {
     cache.assign(2, 10)
 
     //Then later update should have been ignored
-    assertEquals(6, cache.epochEntries(0).startOffset)
+    assertEquals(6, cache.epochEntries.get(0).startOffset)
   }
 
   @Test
   def shouldReturnUnsupportedIfNoEpochRecorded(): Unit = {
     //Then
-    assertEquals((UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET), 
cache.endOffsetFor(0, 0L))
+    assertEquals((UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET), 
toTuple(cache.endOffsetFor(0, 0L)))
   }
 
   @Test
   def shouldReturnUnsupportedIfNoEpochRecordedAndUndefinedEpochRequested(): 
Unit = {
     //When (say a follower on an older message format version) sends a request for UNDEFINED_EPOCH
-    val offsetFor = cache.endOffsetFor(UNDEFINED_EPOCH, 73)
+    val offsetFor = toTuple(cache.endOffsetFor(UNDEFINED_EPOCH, 73))
 
     //Then
     assertEquals((UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET),
@@ -153,12 +156,12 @@ class LeaderEpochFileCacheTest {
 
   @Test
   def shouldReturnFirstEpochIfRequestedEpochLessThanFirstEpoch(): Unit = {
-    cache.assign(epoch = 5, startOffset = 11)
-    cache.assign(epoch = 6, startOffset = 12)
-    cache.assign(epoch = 7, startOffset = 13)
+    cache.assign(5, 11)
+    cache.assign(6, 12)
+    cache.assign(7, 13)
 
     //When
-    val epochAndOffset = cache.endOffsetFor(4, 0L)
+    val epochAndOffset = toTuple(cache.endOffsetFor(4, 0L))
 
     //Then
     assertEquals((4, 11), epochAndOffset)
@@ -166,100 +169,100 @@ class LeaderEpochFileCacheTest {
 
   @Test
   def shouldTruncateIfMatchingEpochButEarlierStartingOffset(): Unit = {
-    cache.assign(epoch = 5, startOffset = 11)
-    cache.assign(epoch = 6, startOffset = 12)
-    cache.assign(epoch = 7, startOffset = 13)
+    cache.assign(5, 11)
+    cache.assign(6, 12)
+    cache.assign(7, 13)
 
     // epoch 7 starts at an earlier offset
-    cache.assign(epoch = 7, startOffset = 12)
+    cache.assign(7, 12)
 
-    assertEquals((5, 12), cache.endOffsetFor(5, 0L))
-    assertEquals((5, 12), cache.endOffsetFor(6, 0L))
+    assertEquals((5, 12), toTuple(cache.endOffsetFor(5, 0L)))
+    assertEquals((5, 12), toTuple(cache.endOffsetFor(6, 0L)))
   }
 
   @Test
   def 
shouldGetFirstOffsetOfSubsequentEpochWhenOffsetRequestedForPreviousEpoch() = {
     //When several epochs
-    cache.assign(epoch = 1, startOffset = 11)
-    cache.assign(epoch = 1, startOffset = 12)
-    cache.assign(epoch = 2, startOffset = 13)
-    cache.assign(epoch = 2, startOffset = 14)
-    cache.assign(epoch = 3, startOffset = 15)
-    cache.assign(epoch = 3, startOffset = 16)
+    cache.assign(1, 11)
+    cache.assign(1, 12)
+    cache.assign(2, 13)
+    cache.assign(2, 14)
+    cache.assign(3, 15)
+    cache.assign(3, 16)
 
     //Then get the start offset of the next epoch
-    assertEquals((2, 15), cache.endOffsetFor(2, 17))
+    assertEquals((2, 15), toTuple(cache.endOffsetFor(2, 17)))
   }
 
   @Test
   def shouldReturnNextAvailableEpochIfThereIsNoExactEpochForTheOneRequested(): 
Unit = {
     //When
-    cache.assign(epoch = 0, startOffset = 10)
-    cache.assign(epoch = 2, startOffset = 13)
-    cache.assign(epoch = 4, startOffset = 17)
+    cache.assign(0, 10)
+    cache.assign(2, 13)
+    cache.assign(4, 17)
 
     //Then
-    assertEquals((0, 13), cache.endOffsetFor(1, 0L))
-    assertEquals((2, 17), cache.endOffsetFor(2, 0L))
-    assertEquals((2, 17), cache.endOffsetFor(3, 0L))
+    assertEquals((0, 13), toTuple(cache.endOffsetFor(1, 0L)))
+    assertEquals((2, 17), toTuple(cache.endOffsetFor(2, 0L)))
+    assertEquals((2, 17), toTuple(cache.endOffsetFor(3, 0L)))
   }
 
   @Test
   def shouldNotUpdateEpochAndStartOffsetIfItDidNotChange() = {
     //When
-    cache.assign(epoch = 2, startOffset = 6)
-    cache.assign(epoch = 2, startOffset = 7)
+    cache.assign(2, 6)
+    cache.assign(2, 7)
 
     //Then
     assertEquals(1, cache.epochEntries.size)
-    assertEquals(EpochEntry(2, 6), cache.epochEntries.toList(0))
+    assertEquals(new EpochEntry(2, 6), cache.epochEntries.get(0))
   }
 
   @Test
   def shouldReturnInvalidOffsetIfEpochIsRequestedWhichIsNotCurrentlyTracked(): 
Unit = {
     //When
-    cache.assign(epoch = 2, startOffset = 100)
+    cache.assign(2, 100)
 
     //Then
-    assertEquals((UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET), 
cache.endOffsetFor(3, 100))
+    assertEquals((UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET), 
toTuple(cache.endOffsetFor(3, 100)))
   }
 
   @Test
   def shouldSupportEpochsThatDoNotStartFromZero(): Unit = {
     //When
-    cache.assign(epoch = 2, startOffset = 6)
+    cache.assign(2, 6)
     val logEndOffset = 7
 
     //Then
-    assertEquals((2, logEndOffset), cache.endOffsetFor(2, logEndOffset))
+    assertEquals((2, logEndOffset), toTuple(cache.endOffsetFor(2, 
logEndOffset)))
     assertEquals(1, cache.epochEntries.size)
-    assertEquals(EpochEntry(2, 6), cache.epochEntries(0))
+    assertEquals(new EpochEntry(2, 6), cache.epochEntries.get(0))
   }
 
   @Test
   def shouldPersistEpochsBetweenInstances(): Unit = {
     val checkpointPath = TestUtils.tempFile().getAbsolutePath
-    val checkpoint = new LeaderEpochCheckpointFile(new File(checkpointPath))
+    val checkpoint = new LeaderEpochCheckpointFile(new File(checkpointPath), 
new LogDirFailureChannel(1))
 
     //Given
     val cache = new LeaderEpochFileCache(tp, checkpoint)
-    cache.assign(epoch = 2, startOffset = 6)
+    cache.assign(2, 6)
 
     //When
-    val checkpoint2 = new LeaderEpochCheckpointFile(new File(checkpointPath))
+    val checkpoint2 = new LeaderEpochCheckpointFile(new File(checkpointPath), 
new LogDirFailureChannel(1))
     val cache2 = new LeaderEpochFileCache(tp, checkpoint2)
 
     //Then
     assertEquals(1, cache2.epochEntries.size)
-    assertEquals(EpochEntry(2, 6), cache2.epochEntries.toList(0))
+    assertEquals(new EpochEntry(2, 6), cache2.epochEntries.get(0))
   }
 
   @Test
   def shouldEnforceMonotonicallyIncreasingEpochs(): Unit = {
     //Given
-    cache.assign(epoch = 1, startOffset = 5);
+    cache.assign(1, 5);
     var logEndOffset = 6
-    cache.assign(epoch = 2, startOffset = 6);
+    cache.assign(2, 6);
     logEndOffset = 7
 
     //When we update an epoch in the past with a different offset, the log has 
already reached
@@ -267,28 +270,32 @@ class LeaderEpochFileCacheTest {
     //or truncate the cached epochs to the point of conflict. We take this 
latter approach in
     //order to guarantee that epochs and offsets in the cache increase 
monotonically, which makes
     //the search logic simpler to reason about.
-    cache.assign(epoch = 1, startOffset = 7);
+    cache.assign(1, 7);
     logEndOffset = 8
 
     //Then later epochs will be removed
-    assertEquals(Some(1), cache.latestEpoch)
+    assertEquals(OptionalInt.of(1), cache.latestEpoch)
 
     //Then end offset for epoch 1 will have changed
-    assertEquals((1, 8), cache.endOffsetFor(1, logEndOffset))
+    assertEquals((1, 8), toTuple(cache.endOffsetFor(1, logEndOffset)))
 
     //Then end offset for epoch 2 is now undefined
-    assertEquals((UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET), 
cache.endOffsetFor(2, logEndOffset))
-    assertEquals(EpochEntry(1, 7), cache.epochEntries(0))
+    assertEquals((UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET), 
toTuple(cache.endOffsetFor(2, logEndOffset)))
+    assertEquals(new EpochEntry(1, 7), cache.epochEntries.get(0))
+  }
+
+  private def toTuple[K, V](entry: java.util.Map.Entry[K, V]): (K, V) = {
+    (entry.getKey, entry.getValue)
   }
 
   @Test
   def shouldEnforceOffsetsIncreaseMonotonically() = {
     //When epoch goes forward but offset goes backwards
-    cache.assign(epoch = 2, startOffset = 6)
-    cache.assign(epoch = 3, startOffset = 5)
+    cache.assign(2, 6)
+    cache.assign(3, 5)
 
     //The last assignment wins and the conflicting one is removed from the log
-    assertEquals(EpochEntry(3, 5), cache.epochEntries.toList(0))
+    assertEquals(new EpochEntry(3, 5), cache.epochEntries.get(0))
   }
 
   @Test
@@ -296,229 +303,230 @@ class LeaderEpochFileCacheTest {
     var logEndOffset = 0L
 
     //Given
-    cache.assign(epoch = 0, startOffset = 0) //logEndOffset=0
+    cache.assign(0, 0) //logEndOffset=0
 
     //When
-    cache.assign(epoch = 1, startOffset = 0) //logEndOffset=0
+    cache.assign(1, 0) //logEndOffset=0
 
     //Then epoch should go up
-    assertEquals(Some(1), cache.latestEpoch)
+    assertEquals(OptionalInt.of(1), cache.latestEpoch)
     //offset for 1 should still be 0
-    assertEquals((1, 0), cache.endOffsetFor(1, logEndOffset))
+    assertEquals((1, 0), toTuple(cache.endOffsetFor(1, logEndOffset)))
     //offset for epoch 0 should still be 0
-    assertEquals((0, 0), cache.endOffsetFor(0, logEndOffset))
+    assertEquals((0, 0), toTuple(cache.endOffsetFor(0, logEndOffset)))
 
     //When we write 5 messages as epoch 1
     logEndOffset = 5L
 
     //Then end offset for epoch(1) should be logEndOffset => 5
-    assertEquals((1, 5), cache.endOffsetFor(1, logEndOffset))
+    assertEquals((1, 5), toTuple(cache.endOffsetFor(1, logEndOffset)))
     //Epoch 0 should still be at offset 0
-    assertEquals((0, 0), cache.endOffsetFor(0, logEndOffset))
+    assertEquals((0, 0), toTuple(cache.endOffsetFor(0, logEndOffset)))
 
     //When
-    cache.assign(epoch = 2, startOffset = 5) //logEndOffset=5
+    cache.assign(2, 5) //logEndOffset=5
 
     logEndOffset = 10 //write another 5 messages
 
     //Then end offset for epoch(2) should be logEndOffset => 10
-    assertEquals((2, 10), cache.endOffsetFor(2, logEndOffset))
+    assertEquals((2, 10), toTuple(cache.endOffsetFor(2, logEndOffset)))
 
     //end offset for epoch(1) should be the start offset of epoch(2) => 5
-    assertEquals((1, 5), cache.endOffsetFor(1, logEndOffset))
+    assertEquals((1, 5), toTuple(cache.endOffsetFor(1, logEndOffset)))
 
     //epoch (0) should still be 0
-    assertEquals((0, 0), cache.endOffsetFor(0, logEndOffset))
+    assertEquals((0, 0), toTuple(cache.endOffsetFor(0, logEndOffset)))
   }
 
   @Test
   def shouldIncreaseAndTrackEpochsAsFollowerReceivesManyMessages(): Unit = {
     //When Messages come in
-    cache.assign(epoch = 0, startOffset = 0);
+    cache.assign(0, 0);
     var logEndOffset = 1
-    cache.assign(epoch = 0, startOffset = 1);
+    cache.assign(0, 1);
     logEndOffset = 2
-    cache.assign(epoch = 0, startOffset = 2);
+    cache.assign(0, 2);
     logEndOffset = 3
 
     //Then epoch should stay, offsets should grow
-    assertEquals(Some(0), cache.latestEpoch)
-    assertEquals((0, logEndOffset), cache.endOffsetFor(0, logEndOffset))
+    assertEquals(OptionalInt.of(0), cache.latestEpoch)
+    assertEquals((0, logEndOffset), toTuple(cache.endOffsetFor(0, 
logEndOffset)))
 
     //When messages arrive with greater epoch
-    cache.assign(epoch = 1, startOffset = 3);
+    cache.assign(1, 3);
     logEndOffset = 4
-    cache.assign(epoch = 1, startOffset = 4);
+    cache.assign(1, 4);
     logEndOffset = 5
-    cache.assign(epoch = 1, startOffset = 5);
+    cache.assign(1, 5);
     logEndOffset = 6
 
-    assertEquals(Some(1), cache.latestEpoch)
-    assertEquals((1, logEndOffset), cache.endOffsetFor(1, logEndOffset))
+    assertEquals(OptionalInt.of(1), cache.latestEpoch)
+    assertEquals((1, logEndOffset), toTuple(cache.endOffsetFor(1, 
logEndOffset)))
 
     //When
-    cache.assign(epoch = 2, startOffset = 6);
+    cache.assign(2, 6);
     logEndOffset = 7
-    cache.assign(epoch = 2, startOffset = 7);
+    cache.assign(2, 7);
     logEndOffset = 8
-    cache.assign(epoch = 2, startOffset = 8);
+    cache.assign(2, 8);
     logEndOffset = 9
 
-    assertEquals(Some(2), cache.latestEpoch)
-    assertEquals((2, logEndOffset), cache.endOffsetFor(2, logEndOffset))
+    assertEquals(OptionalInt.of(2), cache.latestEpoch)
+    assertEquals((2, logEndOffset), toTuple(cache.endOffsetFor(2, 
logEndOffset)))
 
     //Older epochs should return the start offset of the first message in the 
subsequent epoch.
-    assertEquals((0, 3), cache.endOffsetFor(0, logEndOffset))
-    assertEquals((1, 6), cache.endOffsetFor(1, logEndOffset))
+    assertEquals((0, 3), toTuple(cache.endOffsetFor(0, logEndOffset)))
+    assertEquals((1, 6), toTuple(cache.endOffsetFor(1, logEndOffset)))
   }
 
   @Test
   def shouldDropEntriesOnEpochBoundaryWhenRemovingLatestEntries(): Unit = {
     //Given
-    cache.assign(epoch = 2, startOffset = 6)
-    cache.assign(epoch = 3, startOffset = 8)
-    cache.assign(epoch = 4, startOffset = 11)
+    cache.assign(2, 6)
+    cache.assign(3, 8)
+    cache.assign(4, 11)
 
     //When clear latest on epoch boundary
-    cache.truncateFromEnd(endOffset = 8)
+    cache.truncateFromEnd(8)
 
     //Then should remove two latest epochs (remove is inclusive)
-    assertEquals(ListBuffer(EpochEntry(2, 6)), cache.epochEntries)
+    assertEquals(java.util.Arrays.asList(new EpochEntry(2, 6)), 
cache.epochEntries)
   }
 
   @Test
   def shouldPreserveResetOffsetOnClearEarliestIfOneExists(): Unit = {
     //Given
-    cache.assign(epoch = 2, startOffset = 6)
-    cache.assign(epoch = 3, startOffset = 8)
-    cache.assign(epoch = 4, startOffset = 11)
+    cache.assign(2, 6)
+    cache.assign(3, 8)
+    cache.assign(4, 11)
 
     //When reset to offset ON epoch boundary
-    cache.truncateFromStart(startOffset = 8)
+    cache.truncateFromStart(8)
 
     //Then should preserve (3, 8)
-    assertEquals(ListBuffer(EpochEntry(3, 8), EpochEntry(4, 11)), 
cache.epochEntries)
+    assertEquals(java.util.Arrays.asList(new EpochEntry(3, 8), new 
EpochEntry(4, 11)), cache.epochEntries)
   }
 
   @Test
   def shouldUpdateSavedOffsetWhenOffsetToClearToIsBetweenEpochs(): Unit = {
     //Given
-    cache.assign(epoch = 2, startOffset = 6)
-    cache.assign(epoch = 3, startOffset = 8)
-    cache.assign(epoch = 4, startOffset = 11)
+    cache.assign(2, 6)
+    cache.assign(3, 8)
+    cache.assign(4, 11)
 
     //When reset to offset BETWEEN epoch boundaries
-    cache.truncateFromStart(startOffset = 9)
+    cache.truncateFromStart(9)
 
     //Then we should retain epoch 3, but update its offset to 9 as 8 has been removed
-    assertEquals(ListBuffer(EpochEntry(3, 9), EpochEntry(4, 11)), 
cache.epochEntries)
+    assertEquals(java.util.Arrays.asList(new EpochEntry(3, 9), new 
EpochEntry(4, 11)), cache.epochEntries)
   }
 
   @Test
   def shouldNotClearAnythingIfOffsetToEarly(): Unit = {
     //Given
-    cache.assign(epoch = 2, startOffset = 6)
-    cache.assign(epoch = 3, startOffset = 8)
-    cache.assign(epoch = 4, startOffset = 11)
+    cache.assign(2, 6)
+    cache.assign(3, 8)
+    cache.assign(4, 11)
 
     //When reset to offset before first epoch offset
-    cache.truncateFromStart(startOffset = 1)
+    cache.truncateFromStart(1)
 
     //Then nothing should change
-    assertEquals(ListBuffer(EpochEntry(2, 6),EpochEntry(3, 8), EpochEntry(4, 
11)), cache.epochEntries)
+    assertEquals(java.util.Arrays.asList(new EpochEntry(2, 6), new EpochEntry(3, 8), new EpochEntry(4, 11)), cache.epochEntries)
   }
 
   @Test
   def shouldNotClearAnythingIfOffsetToFirstOffset(): Unit = {
     //Given
-    cache.assign(epoch = 2, startOffset = 6)
-    cache.assign(epoch = 3, startOffset = 8)
-    cache.assign(epoch = 4, startOffset = 11)
+    cache.assign(2, 6)
+    cache.assign(3, 8)
+    cache.assign(4, 11)
 
     //When reset to offset on earliest epoch boundary
-    cache.truncateFromStart(startOffset = 6)
+    cache.truncateFromStart(6)
 
     //Then nothing should change
-    assertEquals(ListBuffer(EpochEntry(2, 6),EpochEntry(3, 8), EpochEntry(4, 
11)), cache.epochEntries)
+    assertEquals(java.util.Arrays.asList(new EpochEntry(2, 6), new EpochEntry(3, 8), new EpochEntry(4, 11)), cache.epochEntries)
   }
 
   @Test
   def shouldRetainLatestEpochOnClearAllEarliest(): Unit = {
     //Given
-    cache.assign(epoch = 2, startOffset = 6)
-    cache.assign(epoch = 3, startOffset = 8)
-    cache.assign(epoch = 4, startOffset = 11)
+    cache.assign(2, 6)
+    cache.assign(3, 8)
+    cache.assign(4, 11)
 
     //When
-    cache.truncateFromStart(startOffset = 11)
+    cache.truncateFromStart(11)
 
     //Then retain the last
-    assertEquals(ListBuffer(EpochEntry(4, 11)), cache.epochEntries)
+    assertEquals(Collections.singletonList(new EpochEntry(4, 11)), 
cache.epochEntries)
   }
 
   @Test
   def shouldUpdateOffsetBetweenEpochBoundariesOnClearEarliest(): Unit = {
     //Given
-    cache.assign(epoch = 2, startOffset = 6)
-    cache.assign(epoch = 3, startOffset = 8)
-    cache.assign(epoch = 4, startOffset = 11)
+    cache.assign(2, 6)
+    cache.assign(3, 8)
+    cache.assign(4, 11)
 
     //When we clear from a position between offset 8 & offset 11
-    cache.truncateFromStart(startOffset = 9)
+    cache.truncateFromStart(9)
 
     //Then we should update the middle epoch entry's offset
-    assertEquals(ListBuffer(EpochEntry(3, 9), EpochEntry(4, 11)), 
cache.epochEntries)
+    assertEquals(java.util.Arrays.asList(new EpochEntry(3, 9), new 
EpochEntry(4, 11)), cache.epochEntries)
   }
 
   @Test
   def shouldUpdateOffsetBetweenEpochBoundariesOnClearEarliest2(): Unit = {
     //Given
-    cache.assign(epoch = 0, startOffset = 0)
-    cache.assign(epoch = 1, startOffset = 7)
-    cache.assign(epoch = 2, startOffset = 10)
+    cache.assign(0, 0)
+    cache.assign(1, 7)
+    cache.assign(2, 10)
 
     //When we clear from a position between offset 0 & offset 7
-    cache.truncateFromStart(startOffset = 5)
+    cache.truncateFromStart(5)
 
     //Then we should keep epoch 0 but update the offset appropriately
-    assertEquals(ListBuffer(EpochEntry(0,5), EpochEntry(1, 7), EpochEntry(2, 
10)), cache.epochEntries)
+    assertEquals(java.util.Arrays.asList(new EpochEntry(0, 5), new EpochEntry(1, 7), new EpochEntry(2, 10)),
+      cache.epochEntries)
   }
 
   @Test
   def shouldRetainLatestEpochOnClearAllEarliestAndUpdateItsOffset(): Unit = {
     //Given
-    cache.assign(epoch = 2, startOffset = 6)
-    cache.assign(epoch = 3, startOffset = 8)
-    cache.assign(epoch = 4, startOffset = 11)
+    cache.assign(2, 6)
+    cache.assign(3, 8)
+    cache.assign(4, 11)
 
     //When reset to offset beyond last epoch
-    cache.truncateFromStart(startOffset = 15)
+    cache.truncateFromStart(15)
 
     //Then update the last
-    assertEquals(ListBuffer(EpochEntry(4, 15)), cache.epochEntries)
+    assertEquals(Collections.singletonList(new EpochEntry(4, 15)), 
cache.epochEntries)
   }
 
   @Test
   def shouldDropEntriesBetweenEpochBoundaryWhenRemovingNewest(): Unit = {
     //Given
-    cache.assign(epoch = 2, startOffset = 6)
-    cache.assign(epoch = 3, startOffset = 8)
-    cache.assign(epoch = 4, startOffset = 11)
+    cache.assign(2, 6)
+    cache.assign(3, 8)
+    cache.assign(4, 11)
 
     //When reset to offset BETWEEN epoch boundaries
-    cache.truncateFromEnd(endOffset = 9)
+    cache.truncateFromEnd(9)
 
     //Then should keep the preceding epochs
-    assertEquals(Some(3), cache.latestEpoch)
-    assertEquals(ListBuffer(EpochEntry(2, 6), EpochEntry(3, 8)), 
cache.epochEntries)
+    assertEquals(OptionalInt.of(3), cache.latestEpoch)
+    assertEquals(java.util.Arrays.asList(new EpochEntry(2, 6), new 
EpochEntry(3, 8)), cache.epochEntries)
   }
 
   @Test
   def shouldClearAllEntries(): Unit = {
     //Given
-    cache.assign(epoch = 2, startOffset = 6)
-    cache.assign(epoch = 3, startOffset = 8)
-    cache.assign(epoch = 4, startOffset = 11)
+    cache.assign(2, 6)
+    cache.assign(3, 8)
+    cache.assign(4, 11)
 
     //When
     cache.clearAndFlush()
@@ -530,12 +538,12 @@ class LeaderEpochFileCacheTest {
   @Test
   def shouldNotResetEpochHistoryHeadIfUndefinedPassed(): Unit = {
     //Given
-    cache.assign(epoch = 2, startOffset = 6)
-    cache.assign(epoch = 3, startOffset = 8)
-    cache.assign(epoch = 4, startOffset = 11)
+    cache.assign(2, 6)
+    cache.assign(3, 8)
+    cache.assign(4, 11)
 
     //When reset to offset on epoch boundary
-    cache.truncateFromStart(startOffset = UNDEFINED_EPOCH_OFFSET)
+    cache.truncateFromStart(UNDEFINED_EPOCH_OFFSET)
 
     //Then should do nothing
     assertEquals(3, cache.epochEntries.size)
@@ -544,12 +552,12 @@ class LeaderEpochFileCacheTest {
   @Test
   def shouldNotResetEpochHistoryTailIfUndefinedPassed(): Unit = {
     //Given
-    cache.assign(epoch = 2, startOffset = 6)
-    cache.assign(epoch = 3, startOffset = 8)
-    cache.assign(epoch = 4, startOffset = 11)
+    cache.assign(2, 6)
+    cache.assign(3, 8)
+    cache.assign(4, 11)
 
     //When reset to offset on epoch boundary
-    cache.truncateFromEnd(endOffset = UNDEFINED_EPOCH_OFFSET)
+    cache.truncateFromEnd(UNDEFINED_EPOCH_OFFSET)
 
     //Then should do nothing
     assertEquals(3, cache.epochEntries.size)
@@ -558,13 +566,13 @@ class LeaderEpochFileCacheTest {
   @Test
   def shouldFetchLatestEpochOfEmptyCache(): Unit = {
     //Then
-    assertEquals(None, cache.latestEpoch)
+    assertEquals(OptionalInt.empty(), cache.latestEpoch)
   }
 
   @Test
   def shouldFetchEndOffsetOfEmptyCache(): Unit = {
     //Then
-    assertEquals((UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET), 
cache.endOffsetFor(7, 0L))
+    assertEquals((UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET), 
toTuple(cache.endOffsetFor(7, 0L)))
   }
 
   @Test
@@ -581,56 +589,56 @@ class LeaderEpochFileCacheTest {
 
   @Test
   def testFindPreviousEpoch(): Unit = {
-    assertEquals(None, cache.previousEpoch(epoch = 2))
+    assertEquals(OptionalInt.empty(), cache.previousEpoch(2))
 
-    cache.assign(epoch = 2, startOffset = 10)
-    assertEquals(None, cache.previousEpoch(epoch = 2))
+    cache.assign(2, 10)
+    assertEquals(OptionalInt.empty(), cache.previousEpoch(2))
 
-    cache.assign(epoch = 4, startOffset = 15)
-    assertEquals(Some(2), cache.previousEpoch(epoch = 4))
+    cache.assign(4, 15)
+    assertEquals(OptionalInt.of(2), cache.previousEpoch(4))
 
-    cache.assign(epoch = 10, startOffset = 20)
-    assertEquals(Some(4), cache.previousEpoch(epoch = 10))
+    cache.assign(10, 20)
+    assertEquals(OptionalInt.of(4), cache.previousEpoch(10))
 
     cache.truncateFromEnd(18)
-    assertEquals(Some(2), cache.previousEpoch(cache.latestEpoch.get))
+    assertEquals(OptionalInt.of(2), 
cache.previousEpoch(cache.latestEpoch.getAsInt))
   }
 
   @Test
   def testFindNextEpoch(): Unit = {
-    cache.assign(epoch = 0, startOffset = 0)
-    cache.assign(epoch = 1, startOffset = 100)
-    cache.assign(epoch = 2, startOffset = 200)
+    cache.assign(0, 0)
+    cache.assign(1, 100)
+    cache.assign(2, 200)
 
-    assertEquals(Some(0), cache.nextEpoch(epoch = -1))
-    assertEquals(Some(1), cache.nextEpoch(epoch = 0))
-    assertEquals(Some(2), cache.nextEpoch(epoch = 1))
-    assertEquals(None, cache.nextEpoch(epoch = 2))
-    assertEquals(None, cache.nextEpoch(epoch = 100))
+    assertEquals(OptionalInt.of(0), cache.nextEpoch(-1))
+    assertEquals(OptionalInt.of(1), cache.nextEpoch(0))
+    assertEquals(OptionalInt.of(2), cache.nextEpoch(1))
+    assertEquals(OptionalInt.empty(), cache.nextEpoch(2))
+    assertEquals(OptionalInt.empty(), cache.nextEpoch(100))
   }
 
   @Test
   def testGetEpochEntry(): Unit = {
-    cache.assign(epoch = 2, startOffset = 100)
-    cache.assign(epoch = 3, startOffset = 500)
-    cache.assign(epoch = 5, startOffset = 1000)
+    cache.assign(2, 100)
+    cache.assign(3, 500)
+    cache.assign(5, 1000)
 
-    assertEquals(EpochEntry(2, 100), cache.epochEntry(2).get)
-    assertEquals(EpochEntry(3, 500), cache.epochEntry(3).get)
-    assertEquals(EpochEntry(5, 1000), cache.epochEntry(5).get)
+    assertEquals(new EpochEntry(2, 100), cache.epochEntry(2).get)
+    assertEquals(new EpochEntry(3, 500), cache.epochEntry(3).get)
+    assertEquals(new EpochEntry(5, 1000), cache.epochEntry(5).get)
   }
 
   @Test
   def shouldFetchEpochForGivenOffset(): Unit = {
-    cache.assign(epoch = 0, startOffset = 10)
-    cache.assign(epoch = 1, startOffset = 20)
-    cache.assign(epoch = 5, startOffset = 30)
-
-    assertEquals(Some(1), cache.epochForOffset(offset = 25))
-    assertEquals(Some(1), cache.epochForOffset(offset = 20))
-    assertEquals(Some(5), cache.epochForOffset(offset = 30))
-    assertEquals(Some(5), cache.epochForOffset(offset = 50))
-    assertEquals(None, cache.epochForOffset(offset = 5))
+    cache.assign(0, 10)
+    cache.assign(1, 20)
+    cache.assign(5, 30)
+
+    assertEquals(OptionalInt.of(1), cache.epochForOffset(25))
+    assertEquals(OptionalInt.of(1), cache.epochForOffset(20))
+    assertEquals(OptionalInt.of(5), cache.epochForOffset(30))
+    assertEquals(OptionalInt.of(5), cache.epochForOffset(50))
+    assertEquals(OptionalInt.empty(), cache.epochForOffset(5))
   }
 
 }
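
    On the Java side the cache exposes OptionalInt where the Scala version
    returned Option[Int], and endOffsetFor returns a java.util.Map.Entry
    pair, which the test unpacks with the toTuple helper added above. A
    hedged usage sketch, reusing the hypothetical InMemoryLeaderEpochCheckpoint
    from the earlier note, with method names as they appear in this diff:

    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache;

    import java.util.Map;
    import java.util.OptionalInt;

    public class EpochCacheAccessDemo {
        public static void main(String[] args) {
            LeaderEpochFileCache cache = new LeaderEpochFileCache(
                    new TopicPartition("demo", 0), new InMemoryLeaderEpochCheckpoint());
            cache.assign(2, 10);

            OptionalInt latest = cache.latestEpoch();     // OptionalInt[2]
            // endOffsetFor returns a (requestedEpoch, endOffset) pair as a
            // java.util.Map.Entry rather than a Scala tuple.
            Map.Entry<Integer, Long> end = cache.endOffsetFor(2, 11L);
            System.out.println(latest + " -> (" + end.getKey() + ", " + end.getValue() + ")");
        }
    }
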
diff --git 
a/storage/src/main/java/org/apache/kafka/server/log/internals/EpochEntry.java 
b/storage/src/main/java/org/apache/kafka/server/log/internals/EpochEntry.java
new file mode 100644
index 00000000000..330444b9617
--- /dev/null
+++ 
b/storage/src/main/java/org/apache/kafka/server/log/internals/EpochEntry.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.server.log.internals;
+
+// Mapping of an epoch to the offset of the first message in that epoch
+public class EpochEntry {
+    public final int epoch;
+    public final long startOffset;
+
+    public EpochEntry(int epoch, long startOffset) {
+        this.epoch = epoch;
+        this.startOffset = startOffset;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        EpochEntry that = (EpochEntry) o;
+        return epoch == that.epoch && startOffset == that.startOffset;
+    }
+
+    @Override
+    public int hashCode() {
+        int result = epoch;
+        result = 31 * result + Long.hashCode(startOffset);
+        return result;
+    }
+
+    @Override
+    public String toString() {
+        return "EpochEntry(" +
+                "epoch=" + epoch +
+                ", startOffset=" + startOffset +
+                ')';
+    }
+}
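
    EpochEntry is an immutable value class whose equals/hashCode are
    content-based, which is what allows the test changes above to compare
    freshly constructed instances against cache and checkpoint contents. A
    minimal sketch (not part of the commit):

    import org.apache.kafka.server.log.internals.EpochEntry;

    import java.util.HashSet;
    import java.util.Set;

    public class EpochEntryDemo {
        public static void main(String[] args) {
            EpochEntry a = new EpochEntry(2, 10L);
            EpochEntry b = new EpochEntry(2, 10L);
            // equals/hashCode are content-based, so distinct instances with
            // the same epoch and startOffset are interchangeable in sets.
            System.out.println(a.equals(b));         // true
            Set<EpochEntry> set = new HashSet<>();
            set.add(a);
            System.out.println(set.contains(b));     // true
            System.out.println(a);                   // EpochEntry(epoch=2, startOffset=10)
        }
    }
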
diff --git 
a/storage/src/main/java/org/apache/kafka/storage/internals/checkpoint/CheckpointFileWithFailureHandler.java
 
b/storage/src/main/java/org/apache/kafka/storage/internals/checkpoint/CheckpointFileWithFailureHandler.java
new file mode 100644
index 00000000000..04c2e1dd459
--- /dev/null
+++ 
b/storage/src/main/java/org/apache/kafka/storage/internals/checkpoint/CheckpointFileWithFailureHandler.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.storage.internals.checkpoint;
+
+import org.apache.kafka.common.errors.KafkaStorageException;
+import org.apache.kafka.server.common.CheckpointFile;
+import org.apache.kafka.server.log.internals.LogDirFailureChannel;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.List;
+
+public class CheckpointFileWithFailureHandler<T> {
+
+    public final File file;
+    private final LogDirFailureChannel logDirFailureChannel;
+    private final String logDir;
+
+    private final CheckpointFile<T> checkpointFile;
+
+    public CheckpointFileWithFailureHandler(File file, int version, 
CheckpointFile.EntryFormatter<T> formatter,
+                                            LogDirFailureChannel 
logDirFailureChannel, String logDir) throws IOException {
+        this.file = file;
+        this.logDirFailureChannel = logDirFailureChannel;
+        this.logDir = logDir;
+        checkpointFile = new CheckpointFile<>(file, version, formatter);
+    }
+
+    public void write(Collection<T> entries) {
+        try {
+            checkpointFile.write(entries);
+        } catch (IOException e) {
+            String msg = "Error while writing to checkpoint file " + 
file.getAbsolutePath();
+            logDirFailureChannel.maybeAddOfflineLogDir(logDir, msg, e);
+            throw new KafkaStorageException(msg, e);
+        }
+    }
+
+    public List<T> read() {
+        try {
+            return checkpointFile.read();
+        } catch (IOException e) {
+            String msg = "Error while reading checkpoint file " + 
file.getAbsolutePath();
+            logDirFailureChannel.maybeAddOfflineLogDir(logDir, msg, e);
+            throw new KafkaStorageException(msg, e);
+        }
+    }
+}
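
    On any IOException, both write and read first report the failure to the
    LogDirFailureChannel (so the affected log directory can be taken offline)
    and then rethrow it as an unchecked KafkaStorageException. A hedged sketch
    of a caller handling that contract (the helper name is hypothetical, not
    part of the commit):

    import org.apache.kafka.common.errors.KafkaStorageException;
    import org.apache.kafka.storage.internals.checkpoint.CheckpointFileWithFailureHandler;

    import java.util.Collection;

    public class CheckpointWriteHelper {
        // Hypothetical helper: write entries and surface the storage failure,
        // knowing the handler has already marked the log dir offline.
        static <T> void safeWrite(CheckpointFileWithFailureHandler<T> checkpoint,
                                  Collection<T> entries) {
            try {
                checkpoint.write(entries);
            } catch (KafkaStorageException e) {
                // The underlying IOException was already handed to the
                // LogDirFailureChannel before this exception was thrown.
                System.err.println("Checkpoint write failed for "
                        + checkpoint.file.getAbsolutePath() + ": " + e.getMessage());
            }
        }
    }
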
diff --git 
a/storage/src/main/java/org/apache/kafka/storage/internals/checkpoint/LeaderEpochCheckpoint.java
 
b/storage/src/main/java/org/apache/kafka/storage/internals/checkpoint/LeaderEpochCheckpoint.java
new file mode 100644
index 00000000000..1032c560f75
--- /dev/null
+++ 
b/storage/src/main/java/org/apache/kafka/storage/internals/checkpoint/LeaderEpochCheckpoint.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.storage.internals.checkpoint;
+
+import org.apache.kafka.server.log.internals.EpochEntry;
+
+import java.util.Collection;
+import java.util.List;
+
+public interface LeaderEpochCheckpoint {
+
+    void write(Collection<EpochEntry> epochs);
+
+    List<EpochEntry> read();
+}
diff --git 
a/storage/src/main/java/org/apache/kafka/storage/internals/checkpoint/LeaderEpochCheckpointFile.java
 
b/storage/src/main/java/org/apache/kafka/storage/internals/checkpoint/LeaderEpochCheckpointFile.java
new file mode 100644
index 00000000000..1ff08287205
--- /dev/null
+++ 
b/storage/src/main/java/org/apache/kafka/storage/internals/checkpoint/LeaderEpochCheckpointFile.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.storage.internals.checkpoint;
+
+import org.apache.kafka.server.common.CheckpointFile.EntryFormatter;
+import org.apache.kafka.server.log.internals.EpochEntry;
+import org.apache.kafka.server.log.internals.LogDirFailureChannel;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.List;
+import java.util.Optional;
+import java.util.regex.Pattern;
+
+/**
+ * This class persists a map of (LeaderEpoch => Offsets) to a file (for a 
certain replica)
+ * <p>
+ * The format in the LeaderEpoch checkpoint file is like this:
+ * -----checkpoint file begin------
+ * 0                <- LeaderEpochCheckpointFile.currentVersion
+ * 2                <- following entries size
+ * 0  1     <- the format is: leader_epoch(int32) start_offset(int64)
+ * 1  2
+ * -----checkpoint file end----------
+ */
+public class LeaderEpochCheckpointFile implements LeaderEpochCheckpoint {
+
+    public static final Formatter FORMATTER = new Formatter();
+
+    private static final String LEADER_EPOCH_CHECKPOINT_FILENAME = 
"leader-epoch-checkpoint";
+    private static final Pattern WHITE_SPACES_PATTERN = 
Pattern.compile("\\s+");
+    private static final int CURRENT_VERSION = 0;
+
+    private final CheckpointFileWithFailureHandler<EpochEntry> checkpoint;
+
+    public LeaderEpochCheckpointFile(File file, LogDirFailureChannel 
logDirFailureChannel) throws IOException {
+        checkpoint = new CheckpointFileWithFailureHandler<>(file, 
CURRENT_VERSION, FORMATTER, logDirFailureChannel, 
file.getParentFile().getParent());
+    }
+
+    public void write(Collection<EpochEntry> epochs) {
+        checkpoint.write(epochs);
+    }
+
+    public List<EpochEntry> read() {
+        return checkpoint.read();
+    }
+
+    public static File newFile(File dir) {
+        return new File(dir, LEADER_EPOCH_CHECKPOINT_FILENAME);
+    }
+
+    private static class Formatter implements EntryFormatter<EpochEntry> {
+
+        public String toString(EpochEntry entry) {
+            return entry.epoch + " " + entry.startOffset;
+        }
+
+        public Optional<EpochEntry> fromString(String line) {
+            String[] strings = WHITE_SPACES_PATTERN.split(line);
+            return (strings.length == 2) ? Optional.of(new 
EpochEntry(Integer.parseInt(strings[0]), Long.parseLong(strings[1]))) : 
Optional.empty();
+        }
+    }
+}
\ No newline at end of file
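
    The javadoc above documents the on-disk layout: a version line, an
    entry-count line, then one "epoch startOffset" pair per line. A hedged
    round-trip sketch under those assumptions (paths are temporary and
    illustrative):

    import org.apache.kafka.server.log.internals.EpochEntry;
    import org.apache.kafka.server.log.internals.LogDirFailureChannel;
    import org.apache.kafka.storage.internals.checkpoint.LeaderEpochCheckpointFile;

    import java.io.File;
    import java.io.IOException;
    import java.nio.file.Files;
    import java.util.Arrays;

    public class CheckpointFormatDemo {
        public static void main(String[] args) throws IOException {
            // newFile() resolves the conventional "leader-epoch-checkpoint" name.
            File dir = Files.createTempDirectory("partition-dir").toFile();
            LeaderEpochCheckpointFile checkpoint = new LeaderEpochCheckpointFile(
                    LeaderEpochCheckpointFile.newFile(dir), new LogDirFailureChannel(1));
            checkpoint.write(Arrays.asList(new EpochEntry(0, 1L), new EpochEntry(1, 2L)));
            // Per the format comment above, the file now contains:
            //   0        <- version
            //   2        <- number of entries
            //   0 1      <- epoch startOffset
            //   1 2
            System.out.println(checkpoint.read());
        }
    }
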
diff --git 
a/storage/src/main/java/org/apache/kafka/storage/internals/epoch/LeaderEpochFileCache.java
 
b/storage/src/main/java/org/apache/kafka/storage/internals/epoch/LeaderEpochFileCache.java
new file mode 100644
index 00000000000..7c6e82bf161
--- /dev/null
+++ 
b/storage/src/main/java/org/apache/kafka/storage/internals/epoch/LeaderEpochFileCache.java
@@ -0,0 +1,403 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.storage.internals.epoch;
+
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.server.log.internals.EpochEntry;
+import org.apache.kafka.storage.internals.checkpoint.LeaderEpochCheckpoint;
+import org.slf4j.Logger;
+
+import java.util.AbstractMap;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.OptionalInt;
+import java.util.TreeMap;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
+import java.util.function.Predicate;
+
+import static org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.UNDEFINED_EPOCH;
+import static org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.UNDEFINED_EPOCH_OFFSET;
+
+/**
+ * Represents a cache of (LeaderEpoch => Offset) mappings for a particular replica.
+ * <p>
+ * Leader Epoch = epoch assigned to each leader by the controller.
+ * Offset = offset of the first message in each epoch.
+ */
+public class LeaderEpochFileCache {
+    private final LeaderEpochCheckpoint checkpoint;
+    private final Logger log;
+
+    private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
+    private final TreeMap<Integer, EpochEntry> epochs = new TreeMap<>();
+
+    /**
+     * @param topicPartition the associated topic partition
+     * @param checkpoint     the checkpoint file
+     */
+    public LeaderEpochFileCache(TopicPartition topicPartition, LeaderEpochCheckpoint checkpoint) {
+        this.checkpoint = checkpoint;
+        LogContext logContext = new LogContext("[LeaderEpochCache " + topicPartition + "] ");
+        log = logContext.logger(LeaderEpochFileCache.class);
+        checkpoint.read().forEach(this::assign);
+    }
+
+    /**
+     * Assigns the supplied Leader Epoch to the supplied Offset
+     * Once the epoch is assigned it cannot be reassigned
+     */
+    public void assign(int epoch, long startOffset) {
+        EpochEntry entry = new EpochEntry(epoch, startOffset);
+        if (assign(entry)) {
+            log.debug("Appended new epoch entry {}. Cache now contains {} 
entries.", entry, epochs.size());
+            flush();
+        }
+    }
+
+    public void assign(List<EpochEntry> entries) {
+        entries.forEach(entry -> {
+            if (assign(entry)) {
+                log.debug("Appended new epoch entry {}. Cache now contains {} 
entries.", entry, epochs.size());
+            }
+        });
+        if (!entries.isEmpty()) flush();
+    }
+
+    private boolean isUpdateNeeded(EpochEntry entry) {
+        return latestEntry().map(epochEntry -> entry.epoch != epochEntry.epoch || entry.startOffset < epochEntry.startOffset).orElse(true);
+    }
+
+    private boolean assign(EpochEntry entry) {
+        if (entry.epoch < 0 || entry.startOffset < 0) {
+            throw new IllegalArgumentException("Received invalid partition leader epoch entry " + entry);
+        }
+
+        // Check whether the append is needed before acquiring the write lock
+        // in order to avoid contention with readers in the common case
+        if (!isUpdateNeeded(entry)) return false;
+
+        lock.writeLock().lock();
+        try {
+            if (isUpdateNeeded(entry)) {
+                maybeTruncateNonMonotonicEntries(entry);
+                epochs.put(entry.epoch, entry);
+                return true;
+            } else {
+                return false;
+            }
+        } finally {
+            lock.writeLock().unlock();
+        }
+    }
+
+    /**
+     * Remove any entries which violate monotonicity prior to appending a new entry
+     */
+    private void maybeTruncateNonMonotonicEntries(EpochEntry newEntry) {
+        List<EpochEntry> removedEpochs = removeFromEnd(entry -> entry.epoch >= 
newEntry.epoch || entry.startOffset >= newEntry.startOffset);
+
+
+        if (removedEpochs.size() > 1 || (!removedEpochs.isEmpty() && removedEpochs.get(0).startOffset != newEntry.startOffset)) {
+
+            // Only log a warning if there were non-trivial removals. If the start offset of the new entry
+            // matches the start offset of the removed epoch, then no data has been written and the truncation
+            // is expected.
+            log.warn("New epoch entry {} caused truncation of conflicting entries {}. Cache now contains {} entries.", newEntry, removedEpochs, epochs.size());
+        }
+    }
+
+    private List<EpochEntry> removeFromEnd(Predicate<EpochEntry> predicate) {
+        return removeWhileMatching(epochs.descendingMap().entrySet().iterator(), predicate);
+    }
+
+    private List<EpochEntry> removeFromStart(Predicate<EpochEntry> predicate) {
+        return removeWhileMatching(epochs.entrySet().iterator(), predicate);
+    }
+
+    private List<EpochEntry> removeWhileMatching(Iterator<Map.Entry<Integer, EpochEntry>> iterator, Predicate<EpochEntry> predicate) {
+        ArrayList<EpochEntry> removedEpochs = new ArrayList<>();
+
+        while (iterator.hasNext()) {
+            EpochEntry entry = iterator.next().getValue();
+            if (predicate.test(entry)) {
+                removedEpochs.add(entry);
+                iterator.remove();
+            } else {
+                return removedEpochs;
+            }
+        }
+
+        return removedEpochs;
+    }
+
+    public boolean nonEmpty() {
+        lock.readLock().lock();
+        try {
+            return !epochs.isEmpty();
+        } finally {
+            lock.readLock().unlock();
+        }
+    }
+
+    public Optional<EpochEntry> latestEntry() {
+        lock.readLock().lock();
+        try {
+            return Optional.ofNullable(epochs.lastEntry()).map(Map.Entry::getValue);
+        } finally {
+            lock.readLock().unlock();
+        }
+    }
+
+    /**
+     * Returns the current Leader Epoch if one exists. This is the latest epoch
+     * which has messages assigned to it.
+     */
+    public OptionalInt latestEpoch() {
+        Optional<EpochEntry> entry = latestEntry();
+        return entry.isPresent() ? OptionalInt.of(entry.get().epoch) : OptionalInt.empty();
+    }
+
+    public OptionalInt previousEpoch() {
+        lock.readLock().lock();
+        try {
+            Optional<Map.Entry<Integer, EpochEntry>> lowerEntry = latestEntry().flatMap(entry -> Optional.ofNullable(epochs.lowerEntry(entry.epoch)));
+            return lowerEntry.isPresent() ? OptionalInt.of(lowerEntry.get().getKey()) : OptionalInt.empty();
+        } finally {
+            lock.readLock().unlock();
+        }
+    }
+
+    /**
+     * Get the earliest cached entry if one exists.
+     */
+    public Optional<EpochEntry> earliestEntry() {
+        lock.readLock().lock();
+        try {
+            return Optional.ofNullable(epochs.firstEntry()).map(x -> x.getValue());
+        } finally {
+            lock.readLock().unlock();
+        }
+    }
+
+    public OptionalInt previousEpoch(int epoch) {
+        lock.readLock().lock();
+        try {
+            return toOptionalInt(epochs.lowerKey(epoch));
+        } finally {
+            lock.readLock().unlock();
+        }
+    }
+
+    public OptionalInt nextEpoch(int epoch) {
+        lock.readLock().lock();
+        try {
+            return toOptionalInt(epochs.higherKey(epoch));
+        } finally {
+            lock.readLock().unlock();
+        }
+    }
+
+    private static OptionalInt toOptionalInt(Integer value) {
+        return (value != null) ? OptionalInt.of(value) : OptionalInt.empty();
+    }
+
+    public Optional<EpochEntry> epochEntry(int epoch) {
+        lock.readLock().lock();
+        try {
+            return Optional.ofNullable(epochs.get(epoch));
+        } finally {
+            lock.readLock().unlock();
+        }
+    }
+
+    /**
+     * Returns the Leader Epoch and the End Offset for a requested Leader Epoch.
+     * <p>
+     * The Leader Epoch returned is the largest epoch less than or equal to the requested Leader
+     * Epoch. The End Offset is the end offset of this epoch, which is defined as the start offset
+     * of the first Leader Epoch larger than the Leader Epoch requested, or else the Log End
+     * Offset if the latest epoch was requested.
+     * <p>
+     * During the upgrade phase, where existing messages may not have a leader epoch,
+     * if requestedEpoch is less than the first epoch cached, UNDEFINED_EPOCH_OFFSET will be returned
+     * so that the follower falls back to the High Watermark.
+     *
+     * @param requestedEpoch requested leader epoch
+     * @param logEndOffset   the existing Log End Offset
+     * @return found leader epoch and end offset
+     */
+    public Map.Entry<Integer, Long> endOffsetFor(int requestedEpoch, long logEndOffset) {
+        lock.readLock().lock();
+        try {
+            Map.Entry<Integer, Long> epochAndOffset = null;
+            if (requestedEpoch == UNDEFINED_EPOCH) {
+                // This may happen if a bootstrapping follower sends a request with undefined epoch or
+                // a follower is on the older message format where leader epochs are not recorded
+                epochAndOffset = new AbstractMap.SimpleImmutableEntry<>(UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET);
+            } else if (latestEpoch().isPresent() && latestEpoch().getAsInt() == requestedEpoch) {
+                // For the leader, the latest epoch is always the current leader epoch that is still being written to.
+                // Followers should not have any reason to query for the end offset of the current epoch, but a consumer
+                // might if it is verifying its committed offset following a group rebalance. In this case, we return
+                // the current log end offset which makes the truncation check work as expected.
+                epochAndOffset = new AbstractMap.SimpleImmutableEntry<>(requestedEpoch, logEndOffset);
+            } else {
+                Map.Entry<Integer, EpochEntry> higherEntry = epochs.higherEntry(requestedEpoch);
+                if (higherEntry == null) {
+                    // The requested epoch is larger than any known epoch. This case should never be hit because
+                    // the latest cached epoch is always the largest.
+                    epochAndOffset = new AbstractMap.SimpleImmutableEntry<>(UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET);
+                } else {
+                    Map.Entry<Integer, EpochEntry> floorEntry = epochs.floorEntry(requestedEpoch);
+                    if (floorEntry == null) {
+                        // The requested epoch is smaller than any known epoch, so we return the start offset of the first
+                        // known epoch which is larger than it. This may be inaccurate as there could have been
+                        // epochs in between, but the point is that the data has already been removed from the log
+                        // and we want to ensure that the follower can replicate correctly beginning from the leader's
+                        // start offset.
+                        epochAndOffset = new AbstractMap.SimpleImmutableEntry<>(requestedEpoch, higherEntry.getValue().startOffset);
+                    } else {
+                        // We have at least one previous epoch and one subsequent epoch. The result is the first
+                        // prior epoch and the starting offset of the first subsequent epoch.
+                        epochAndOffset = new AbstractMap.SimpleImmutableEntry<>(floorEntry.getValue().epoch, higherEntry.getValue().startOffset);
+                    }
+                    }
+                }
+            }
+
+            if (log.isTraceEnabled())
+                log.trace("Processed end offset request for epoch {} and 
returning epoch {} with end offset {} from epoch cache of size {}}", 
requestedEpoch, epochAndOffset.getKey(), epochAndOffset.getValue(), 
epochs.size());
+
+            return epochAndOffset;
+        } finally {
+            lock.readLock().unlock();
+        }
+    }
+
+    /**
+     * Removes all epoch entries from the store with start offsets greater than or equal to the passed offset.
+     */
+    public void truncateFromEnd(long endOffset) {
+        lock.writeLock().lock();
+        try {
+            Optional<EpochEntry> epochEntry = latestEntry();
+            if (endOffset >= 0 && epochEntry.isPresent() && epochEntry.get().startOffset >= endOffset) {
+                List<EpochEntry> removedEntries = removeFromEnd(x -> x.startOffset >= endOffset);
+
+                flush();
+
+                log.debug("Cleared entries {} from epoch cache after 
truncating to end offset {}, leaving {} entries in the cache.", removedEntries, 
endOffset, epochs.size());
+            }
+        } finally {
+            lock.writeLock().unlock();
+        }
+    }
+
+    /**
+     * Clears old epoch entries. This method searches for the oldest epoch < offset, updates the saved epoch offset to
+     * be offset, then clears any previous epoch entries.
+     * <p>
+     * This method is exclusive: so truncateFromStart(6) will retain an entry at offset 6.
+     *
+     * @param startOffset the offset to clear up to
+     */
+    public void truncateFromStart(long startOffset) {
+        lock.writeLock().lock();
+        try {
+            List<EpochEntry> removedEntries = removeFromStart(entry -> entry.startOffset <= startOffset);
+
+            if (!removedEntries.isEmpty()) {
+                EpochEntry firstBeforeStartOffset = removedEntries.get(removedEntries.size() - 1);
+                EpochEntry updatedFirstEntry = new EpochEntry(firstBeforeStartOffset.epoch, startOffset);
+                epochs.put(updatedFirstEntry.epoch, updatedFirstEntry);
+
+                flush();
+
+                log.debug("Cleared entries {} and rewrote first entry {} after 
truncating to start offset {}, leaving {} in the cache.", removedEntries, 
updatedFirstEntry, startOffset, epochs.size());
+            }
+        } finally {
+            lock.writeLock().unlock();
+        }
+    }
+
+    public OptionalInt epochForOffset(long offset) {
+        lock.readLock().lock();
+        try {
+            OptionalInt previousEpoch = OptionalInt.empty();
+            for (EpochEntry epochEntry : epochs.values()) {
+                int epoch = epochEntry.epoch;
+                long startOffset = epochEntry.startOffset;
+
+                // Found the exact offset, return the respective epoch.
+                if (startOffset == offset) return OptionalInt.of(epoch);
+
+                // Return the previously found epoch as this epoch's start-offset is more than the target offset.
+                if (startOffset > offset) return previousEpoch;
+
+                previousEpoch = OptionalInt.of(epoch);
+            }
+
+            return previousEpoch;
+        } finally {
+            lock.readLock().unlock();
+        }
+    }
+
+    /**
+     * Delete all entries.
+     */
+    public void clearAndFlush() {
+        lock.writeLock().lock();
+        try {
+            epochs.clear();
+            flush();
+        } finally {
+            lock.writeLock().unlock();
+        }
+    }
+
+    public void clear() {
+        lock.writeLock().lock();
+        try {
+            epochs.clear();
+        } finally {
+            lock.writeLock().unlock();
+        }
+    }
+
+    // Visible for testing
+    public List<EpochEntry> epochEntries() {
+        lock.writeLock().lock();
+        try {
+            return new ArrayList<>(epochs.values());
+        } finally {
+            lock.writeLock().unlock();
+        }
+    }
+
+    private void flush() {
+        lock.readLock().lock();
+        try {
+            checkpoint.write(epochs.values());
+        } finally {
+            lock.readLock().unlock();
+        }
+    }
+
+}
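
As a rough usage sketch (not part of this commit), the cache can be exercised against an in-memory checkpoint. This assumes LeaderEpochCheckpoint only declares write(Collection<EpochEntry>) and read(), as implied by LeaderEpochCheckpointFile above:

    import java.util.ArrayList;
    import java.util.Collection;
    import java.util.List;
    import java.util.Map;

    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.server.log.internals.EpochEntry;
    import org.apache.kafka.storage.internals.checkpoint.LeaderEpochCheckpoint;
    import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache;

    public class LeaderEpochCacheExample {
        public static void main(String[] args) {
            // In-memory stand-in for the on-disk leader-epoch-checkpoint file.
            LeaderEpochCheckpoint checkpoint = new LeaderEpochCheckpoint() {
                private List<EpochEntry> entries = new ArrayList<>();
                @Override public void write(Collection<EpochEntry> epochs) { entries = new ArrayList<>(epochs); }
                @Override public List<EpochEntry> read() { return entries; }
            };

            LeaderEpochFileCache cache = new LeaderEpochFileCache(new TopicPartition("topic", 0), checkpoint);
            cache.assign(0, 0L);    // epoch 0 starts at offset 0
            cache.assign(1, 100L);  // epoch 1 starts at offset 100

            // End offset of epoch 0 is the start offset of the next epoch (100).
            Map.Entry<Integer, Long> epochAndOffset = cache.endOffsetFor(0, 250L);
            System.out.println(epochAndOffset);                // 0 -> 100

            // For the latest epoch, the supplied log end offset is returned.
            System.out.println(cache.endOffsetFor(1, 250L));   // 1 -> 250

            // truncateFromStart(50) drops older entries and rewrites the first retained
            // entry so that epoch 0 now starts at offset 50; epoch 1 is unchanged.
            cache.truncateFromStart(50L);
            System.out.println(cache.epochEntries());
        }
    }

The endOffsetFor calls above illustrate the javadoc for that method: a query for epoch 0 returns the start offset of the following epoch (100), while a query for the current epoch simply returns the supplied log end offset.
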
