This is an automated email from the ASF dual-hosted git repository.

ijuma pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new cf7029c0264 KAFKA-13093: Log compaction should write new segments with record version v2 (KIP-724) (#18321)
cf7029c0264 is described below

commit cf7029c0264fd7f7b15c2e98acc874ec8c3403f2
Author: Ismael Juma <[email protected]>
AuthorDate: Thu Jan 9 09:37:23 2025 -0800

    KAFKA-13093: Log compaction should write new segments with record version v2 (KIP-724) (#18321)
    
    Convert v0/v1 record batches to v2 during compaction, even if said record batches would otherwise be written unchanged. A few important details:
    
    1. A V0 compressed record batch with multiple records is converted into a single V2 record batch.
    2. V0 uncompressed records are converted into single-record V2 record batches.
    3. V0 records are converted to V2 records with timestampType set to `CreateTime` and the timestamp set to `-1`.
    4. The `KAFKA-4298` workaround is no longer needed since the conversion to V2 also fixes that issue.
    5. Removed a log warning that applied to consumers older than 0.10.1 - they are no longer supported.
    6. Added back the ability to append records with v0/v1 (for testing only).
    7. The creation of the leader epoch cache is no longer optional since the record version config is effectively always V2.
    
    Added integration tests; these existed before #18267 and have been restored, modified and extended.
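    
    For reference, a minimal sketch of the simplified filter API after this change (the
    signature matches the MemoryRecords diff below; the filter body is illustrative):
    
        MemoryRecords.FilterResult result = records.filterTo(
            new MemoryRecords.RecordFilter(0, 0) {
                @Override
                protected BatchRetentionResult checkBatchRetention(RecordBatch batch) {
                    // delete batches that end up empty after filtering
                    return new BatchRetentionResult(BatchRetention.DELETE_EMPTY, false);
                }
                @Override
                protected boolean shouldRetainRecord(RecordBatch batch, Record record) {
                    // compaction-style rule: drop records without keys
                    return record.key() != null;
                }
            }, ByteBuffer.allocate(1024), BufferSupplier.NO_CACHING);
    
    Retained batches with record version v0/v1 are recopied as v2 during this pass.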
    
    Reviewers: Jun Rao <[email protected]>
---
 .../apache/kafka/common/record/MemoryRecords.java  |  87 +++++-------
 .../internals/FetchRequestManagerTest.java         |   4 +-
 .../clients/consumer/internals/FetcherTest.java    |   4 +-
 .../internals/ShareConsumeRequestManagerTest.java  |   4 +-
 .../kafka/common/record/MemoryRecordsTest.java     |  43 +++---
 .../java/kafka/log/remote/RemoteLogManager.java    |  98 +++++---------
 .../main/java/kafka/server/TierStateMachine.java   |   4 +-
 core/src/main/scala/kafka/cluster/Partition.scala  |   2 +-
 core/src/main/scala/kafka/log/LogCleaner.scala     |   5 +-
 core/src/main/scala/kafka/log/UnifiedLog.scala     | 135 +++++++++----------
 .../main/scala/kafka/raft/KafkaMetadataLog.scala   |   2 +-
 .../scala/kafka/server/LocalLeaderEndPoint.scala   |   6 +-
 .../kafka/log/remote/RemoteLogManagerTest.java     |  78 ++++++-----
 .../unit/kafka/cluster/PartitionLockTest.scala     |   9 +-
 .../scala/unit/kafka/cluster/PartitionTest.scala   |  12 +-
 .../log/AbstractLogCleanerIntegrationTest.scala    |   6 +-
 .../unit/kafka/log/LogCleanerManagerTest.scala     |   7 +-
 .../LogCleanerParameterizedIntegrationTest.scala   | 150 ++++++++++++++++++++-
 .../test/scala/unit/kafka/log/LogCleanerTest.scala |   7 +-
 .../test/scala/unit/kafka/log/LogLoaderTest.scala  |  38 +++---
 .../test/scala/unit/kafka/log/LogTestUtils.scala   |   7 -
 .../test/scala/unit/kafka/log/UnifiedLogTest.scala |  74 +++++-----
 .../unit/kafka/server/ReplicaManagerTest.scala     |  13 +-
 .../scala/unit/kafka/utils/SchedulerTest.scala     |   8 +-
 .../kafka/storage/internals/log/LogLoader.java     |  12 +-
 .../kafka/storage/internals/log/LogSegment.java    |  12 +-
 .../storage/internals/log/LogSegmentTest.java      |  12 +-
 .../tiered/storage/TieredStorageTestContext.java   |   6 +-
 .../actions/ExpectLeaderEpochCheckpointAction.java |   8 +-
 29 files changed, 463 insertions(+), 390 deletions(-)
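
A note on the timestamp rule from point 3 of the commit message, condensed from the
MemoryRecords change below (a sketch of the committed logic, not additional code): v0 batches
carry NO_TIMESTAMP_TYPE, so recopied batches are stamped with CREATE_TIME, and v0 records keep
their timestamp of -1 (RecordBatch.NO_TIMESTAMP):

    TimestampType timestampType = originalBatch.timestampType() == TimestampType.NO_TIMESTAMP_TYPE ?
            TimestampType.CREATE_TIME : originalBatch.timestampType();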

diff --git a/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java b/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
index 650071474db..3aee889aded 100644
--- a/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
+++ b/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
@@ -16,7 +16,6 @@
  */
 package org.apache.kafka.common.record;
 
-import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.compress.Compression;
 import org.apache.kafka.common.errors.CorruptRecordException;
 import org.apache.kafka.common.message.KRaftVersionRecord;
@@ -137,12 +136,8 @@ public class MemoryRecords extends AbstractRecords {
     /**
      * Filter the records into the provided ByteBuffer.
      *
-     * @param partition                   The partition that is filtered (used only for logging)
      * @param filter                      The filter function
      * @param destinationBuffer           The byte buffer to write the filtered records to
-     * @param maxRecordBatchSize          The maximum record batch size. Note this is not a hard limit: if a batch
-     *                                    exceeds this after filtering, we log a warning, but the batch will still be
-     *                                    created.
      * @param decompressionBufferSupplier The supplier of ByteBuffer(s) used for decompression if supported. For small
      *                                    record batches, allocating a potentially large buffer (64 KB for LZ4) will
      *                                    dominate the cost of decompressing and iterating over the records in the
@@ -150,18 +145,16 @@ public class MemoryRecords extends AbstractRecords {
      *                                    performance impact.
      * @return A FilterResult with a summary of the output (for metrics) and potentially an overflow buffer
      */
-    public FilterResult filterTo(TopicPartition partition, RecordFilter filter, ByteBuffer destinationBuffer,
-                                 int maxRecordBatchSize, BufferSupplier decompressionBufferSupplier) {
-        return filterTo(partition, batches(), filter, destinationBuffer, maxRecordBatchSize, decompressionBufferSupplier);
+    public FilterResult filterTo(RecordFilter filter, ByteBuffer destinationBuffer, BufferSupplier decompressionBufferSupplier) {
+        return filterTo(batches(), filter, destinationBuffer, decompressionBufferSupplier);
     }
 
     /**
      * Note: This method is also used to convert the first timestamp of the batch (which is usually the timestamp of the first record)
      * to the delete horizon of the tombstones or txn markers which are present in the batch.
      */
-    private static FilterResult filterTo(TopicPartition partition, Iterable<MutableRecordBatch> batches,
-                                         RecordFilter filter, ByteBuffer destinationBuffer, int maxRecordBatchSize,
-                                         BufferSupplier decompressionBufferSupplier) {
+    private static FilterResult filterTo(Iterable<MutableRecordBatch> batches, RecordFilter filter,
+                                         ByteBuffer destinationBuffer, BufferSupplier decompressionBufferSupplier) {
         FilterResult filterResult = new FilterResult(destinationBuffer);
         ByteBufferOutputStream bufferOutputStream = new ByteBufferOutputStream(destinationBuffer);
         for (MutableRecordBatch batch : batches) {
@@ -174,15 +167,9 @@ public class MemoryRecords extends AbstractRecords {
             if (batchRetention == BatchRetention.DELETE)
                 continue;
 
-            // We use the absolute offset to decide whether to retain the message or not. Due to KAFKA-4298, we have to
-            // allow for the possibility that a previous version corrupted the log by writing a compressed record batch
-            // with a magic value not matching the magic of the records (magic < 2). This will be fixed as we
-            // recopy the messages to the destination buffer.
-            byte batchMagic = batch.magic();
-            List<Record> retainedRecords = new ArrayList<>();
-
-            final BatchFilterResult iterationResult = filterBatch(batch, decompressionBufferSupplier, filterResult, filter,
-                    batchMagic, true, retainedRecords);
+            final BatchFilterResult iterationResult = filterBatch(batch, decompressionBufferSupplier, filterResult,
+                filter);
+            List<Record> retainedRecords = iterationResult.retainedRecords;
             boolean containsTombstones = iterationResult.containsTombstones;
             boolean writeOriginalBatch = iterationResult.writeOriginalBatch;
             long maxOffset = iterationResult.maxOffset;
@@ -191,8 +178,8 @@ public class MemoryRecords extends AbstractRecords {
                 // we check if the delete horizon should be set to a new value
                 // in which case, we need to reset the base timestamp and overwrite the timestamp deltas
                 // if the batch does not contain tombstones, then we don't need to overwrite batch
-                boolean needToSetDeleteHorizon = batch.magic() >= RecordBatch.MAGIC_VALUE_V2 && (containsTombstones || containsMarkerForEmptyTxn)
-                    && batch.deleteHorizonMs().isEmpty();
+                boolean needToSetDeleteHorizon = (containsTombstones || containsMarkerForEmptyTxn) &&
+                    batch.deleteHorizonMs().isEmpty();
                 if (writeOriginalBatch && !needToSetDeleteHorizon) {
                     batch.writeTo(bufferOutputStream);
                     filterResult.updateRetainedBatchMetadata(batch, retainedRecords.size(), false);
@@ -202,26 +189,21 @@ public class MemoryRecords extends AbstractRecords {
                        deleteHorizonMs = filter.currentTime + filter.deleteRetentionMs;
                     else
                         deleteHorizonMs = batch.deleteHorizonMs().orElse(RecordBatch.NO_TIMESTAMP);
-                    try (final MemoryRecordsBuilder builder = buildRetainedRecordsInto(batch, retainedRecords, bufferOutputStream, deleteHorizonMs)) {
+                    try (final MemoryRecordsBuilder builder = buildRetainedRecordsInto(batch, retainedRecords,
+                            bufferOutputStream, deleteHorizonMs)) {
                         MemoryRecords records = builder.build();
                         int filteredBatchSize = records.sizeInBytes();
-                        if (filteredBatchSize > batch.sizeInBytes() && filteredBatchSize > maxRecordBatchSize)
-                            log.warn("Record batch from {} with last offset {} exceeded max record batch size {} after cleaning " +
-                                    "(new size is {}). Consumers with version earlier than 0.10.1.0 may need to " +
-                                    "increase their fetch sizes.",
-                                partition, batch.lastOffset(), maxRecordBatchSize, filteredBatchSize);
-
                         MemoryRecordsBuilder.RecordsInfo info = builder.info();
                         filterResult.updateRetainedBatchMetadata(info.maxTimestamp, info.shallowOffsetOfMaxTimestamp,
                             maxOffset, retainedRecords.size(), filteredBatchSize);
                     }
                 }
             } else if (batchRetention == BatchRetention.RETAIN_EMPTY) {
-                if (batchMagic < RecordBatch.MAGIC_VALUE_V2)
+                if (batch.magic() < RecordBatch.MAGIC_VALUE_V2) // should never happen
                     throw new IllegalStateException("Empty batches are only supported for magic v2 and above");
 
                 bufferOutputStream.ensureRemaining(DefaultRecordBatch.RECORD_BATCH_OVERHEAD);
-                DefaultRecordBatch.writeEmptyHeader(bufferOutputStream.buffer(), batchMagic, batch.producerId(),
+                DefaultRecordBatch.writeEmptyHeader(bufferOutputStream.buffer(), RecordBatch.CURRENT_MAGIC_VALUE, batch.producerId(),
                         batch.producerEpoch(), batch.baseSequence(), batch.baseOffset(), batch.lastOffset(),
                         batch.partitionLeaderEpoch(), batch.timestampType(), batch.maxTimestamp(),
                         batch.isTransactional(), batch.isControlBatch());
@@ -243,23 +225,18 @@ public class MemoryRecords extends AbstractRecords {
     private static BatchFilterResult filterBatch(RecordBatch batch,
                                                  BufferSupplier decompressionBufferSupplier,
                                                  FilterResult filterResult,
-                                                 RecordFilter filter,
-                                                 byte batchMagic,
-                                                 boolean writeOriginalBatch,
-                                                 List<Record> retainedRecords) {
-        long maxOffset = -1;
-        boolean containsTombstones = false;
+                                                 RecordFilter filter) {
         try (final CloseableIterator<Record> iterator = batch.streamingIterator(decompressionBufferSupplier)) {
+            long maxOffset = -1;
+            boolean containsTombstones = false;
+            // Convert records with old record versions
+            boolean writeOriginalBatch = batch.magic() >= RecordBatch.CURRENT_MAGIC_VALUE;
+            List<Record> retainedRecords = new ArrayList<>();
             while (iterator.hasNext()) {
                 Record record = iterator.next();
                 filterResult.messagesRead += 1;
 
                 if (filter.shouldRetainRecord(batch, record)) {
-                    // Check for log corruption due to KAFKA-4298. If we find it, make sure that we overwrite
-                    // the corrupted batch with correct data.
-                    if (!record.hasMagic(batchMagic))
-                        writeOriginalBatch = false;
-
                     if (record.offset() > maxOffset)
                         maxOffset = record.offset();
 
@@ -272,17 +249,20 @@ public class MemoryRecords extends AbstractRecords {
                     writeOriginalBatch = false;
                 }
             }
-            return new BatchFilterResult(writeOriginalBatch, containsTombstones, maxOffset);
+            return new BatchFilterResult(retainedRecords, writeOriginalBatch, containsTombstones, maxOffset);
         }
     }
 
     private static class BatchFilterResult {
+        private final List<Record> retainedRecords;
         private final boolean writeOriginalBatch;
         private final boolean containsTombstones;
         private final long maxOffset;
-        private BatchFilterResult(final boolean writeOriginalBatch,
-                                 final boolean containsTombstones,
-                                 final long maxOffset) {
+        private BatchFilterResult(List<Record> retainedRecords,
+                                  final boolean writeOriginalBatch,
+                                  final boolean containsTombstones,
+                                  final long maxOffset) {
+            this.retainedRecords = retainedRecords;
             this.writeOriginalBatch = writeOriginalBatch;
             this.containsTombstones = containsTombstones;
             this.maxOffset = maxOffset;
@@ -293,15 +273,20 @@ public class MemoryRecords extends AbstractRecords {
                                                                  List<Record> retainedRecords,
                                                                  ByteBufferOutputStream bufferOutputStream,
                                                                  final long deleteHorizonMs) {
-        byte magic = originalBatch.magic();
         Compression compression = Compression.of(originalBatch.compressionType()).build();
-        TimestampType timestampType = originalBatch.timestampType();
+        // V0 has no timestamp type or timestamp, so we set the timestamp type to CREATE_TIME and the timestamp to NO_TIMESTAMP.
+        // Note that this differs from produce up-conversion where the timestamp type topic config is used and the log append
+        // time is generated if the config is LOG_APPEND_TIME. The reason for the different behavior is that there is
+        // no appropriate log append time we can generate at compaction time.
+        TimestampType timestampType = originalBatch.timestampType() == TimestampType.NO_TIMESTAMP_TYPE ?
+                TimestampType.CREATE_TIME : originalBatch.timestampType();
         long logAppendTime = timestampType == TimestampType.LOG_APPEND_TIME ?
                 originalBatch.maxTimestamp() : RecordBatch.NO_TIMESTAMP;
-        long baseOffset = magic >= RecordBatch.MAGIC_VALUE_V2 ?
+        long baseOffset = originalBatch.magic() >= RecordBatch.MAGIC_VALUE_V2 ?
                 originalBatch.baseOffset() : retainedRecords.get(0).offset();
 
-        MemoryRecordsBuilder builder = new MemoryRecordsBuilder(bufferOutputStream, magic,
+        // Convert records with older record versions to the current one
+        MemoryRecordsBuilder builder = new MemoryRecordsBuilder(bufferOutputStream, RecordBatch.CURRENT_MAGIC_VALUE,
                 compression, timestampType, baseOffset, logAppendTime, originalBatch.producerId(),
                 originalBatch.producerEpoch(), originalBatch.baseSequence(), originalBatch.isTransactional(),
                 originalBatch.isControlBatch(), originalBatch.partitionLeaderEpoch(), bufferOutputStream.limit(), deleteHorizonMs);
@@ -309,7 +294,7 @@ public class MemoryRecords extends AbstractRecords {
         for (Record record : retainedRecords)
             builder.append(record);
 
-        if (magic >= RecordBatch.MAGIC_VALUE_V2)
+        if (originalBatch.magic() >= RecordBatch.MAGIC_VALUE_V2)
             // we must preserve the last offset from the initial batch in order to ensure that the
             // last sequence number from the batch remains even after compaction. Otherwise, the producer
             // could incorrectly see an out of sequence error.
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetchRequestManagerTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetchRequestManagerTest.java
index 61ea2e8e565..8657dcfc1e9 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetchRequestManagerTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetchRequestManagerTest.java
@@ -2532,7 +2532,7 @@ public class FetchRequestManagerTest {
                 new SimpleRecord(null, "value".getBytes()));
 
         // Remove the last record to simulate compaction
-        MemoryRecords.FilterResult result = records.filterTo(tp0, new MemoryRecords.RecordFilter(0, 0) {
+        MemoryRecords.FilterResult result = records.filterTo(new MemoryRecords.RecordFilter(0, 0) {
             @Override
             protected BatchRetentionResult checkBatchRetention(RecordBatch batch) {
                 return new BatchRetentionResult(BatchRetention.DELETE_EMPTY, false);
@@ -2542,7 +2542,7 @@ public class FetchRequestManagerTest {
             protected boolean shouldRetainRecord(RecordBatch recordBatch, Record record) {
                 return record.key() != null;
             }
-        }, ByteBuffer.allocate(1024), Integer.MAX_VALUE, BufferSupplier.NO_CACHING);
+        }, ByteBuffer.allocate(1024), BufferSupplier.NO_CACHING);
         result.outputBuffer().flip();
         MemoryRecords compactedRecords = MemoryRecords.readableRecords(result.outputBuffer());
 
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
index ab6a9a0c91d..ede973c5f9b 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
@@ -2518,7 +2518,7 @@ public class FetcherTest {
                 new SimpleRecord(null, "value".getBytes()));
 
         // Remove the last record to simulate compaction
-        MemoryRecords.FilterResult result = records.filterTo(tp0, new MemoryRecords.RecordFilter(0, 0) {
+        MemoryRecords.FilterResult result = records.filterTo(new MemoryRecords.RecordFilter(0, 0) {
             @Override
             protected BatchRetentionResult checkBatchRetention(RecordBatch batch) {
                 return new BatchRetentionResult(BatchRetention.DELETE_EMPTY, false);
@@ -2528,7 +2528,7 @@ public class FetcherTest {
             protected boolean shouldRetainRecord(RecordBatch recordBatch, Record record) {
                 return record.key() != null;
             }
-        }, ByteBuffer.allocate(1024), Integer.MAX_VALUE, BufferSupplier.NO_CACHING);
+        }, ByteBuffer.allocate(1024), BufferSupplier.NO_CACHING);
         result.outputBuffer().flip();
         MemoryRecords compactedRecords = MemoryRecords.readableRecords(result.outputBuffer());
 
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ShareConsumeRequestManagerTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ShareConsumeRequestManagerTest.java
index 640eadc0e77..220483cf22d 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ShareConsumeRequestManagerTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ShareConsumeRequestManagerTest.java
@@ -1334,7 +1334,7 @@ public class ShareConsumeRequestManagerTest {
                 new SimpleRecord(null, "value".getBytes()));
 
         // Remove the last record to simulate compaction
-        MemoryRecords.FilterResult result = records.filterTo(tp0, new MemoryRecords.RecordFilter(0, 0) {
+        MemoryRecords.FilterResult result = records.filterTo(new MemoryRecords.RecordFilter(0, 0) {
             @Override
             protected BatchRetentionResult checkBatchRetention(RecordBatch batch) {
                 return new BatchRetentionResult(BatchRetention.DELETE_EMPTY, false);
@@ -1344,7 +1344,7 @@ public class ShareConsumeRequestManagerTest {
             protected boolean shouldRetainRecord(RecordBatch recordBatch, Record record) {
                 return record.key() != null;
             }
-        }, ByteBuffer.allocate(1024), Integer.MAX_VALUE, BufferSupplier.NO_CACHING);
+        }, ByteBuffer.allocate(1024), BufferSupplier.NO_CACHING);
         result.outputBuffer().flip();
         MemoryRecords compactedRecords = MemoryRecords.readableRecords(result.outputBuffer());
 
diff --git a/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsTest.java b/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsTest.java
index 80a77d647b4..3818976e423 100644
--- a/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsTest.java
@@ -16,7 +16,6 @@
  */
 package org.apache.kafka.common.record;
 
-import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.compress.Compression;
 import org.apache.kafka.common.errors.CorruptRecordException;
 import org.apache.kafka.common.header.internals.RecordHeaders;
@@ -291,8 +290,7 @@ public class MemoryRecordsTest {
         builder.append(12L, null, "c".getBytes());
 
         ByteBuffer filtered = ByteBuffer.allocate(2048);
-        builder.build().filterTo(new TopicPartition("foo", 0), new 
RetainNonNullKeysFilter(), filtered,
-                Integer.MAX_VALUE, BufferSupplier.NO_CACHING);
+        builder.build().filterTo(new RetainNonNullKeysFilter(), filtered, 
BufferSupplier.NO_CACHING);
 
         filtered.flip();
         MemoryRecords filteredRecords = 
MemoryRecords.readableRecords(filtered);
@@ -332,7 +330,7 @@ public class MemoryRecordsTest {
                     builder.close();
                     MemoryRecords records = builder.build();
                     ByteBuffer filtered = ByteBuffer.allocate(2048);
-                    MemoryRecords.FilterResult filterResult = records.filterTo(new TopicPartition("foo", 0),
+                    MemoryRecords.FilterResult filterResult = records.filterTo(
                             new MemoryRecords.RecordFilter(0, 0) {
                                 @Override
                                 protected BatchRetentionResult checkBatchRetention(RecordBatch batch) {
@@ -345,7 +343,7 @@ public class MemoryRecordsTest {
                                     // delete the records
                                     return false;
                                 }
-                            }, filtered, Integer.MAX_VALUE, BufferSupplier.NO_CACHING);
+                            }, filtered, BufferSupplier.NO_CACHING);
 
                     // Verify filter result
                     assertEquals(numRecords, filterResult.messagesRead());
@@ -394,7 +392,7 @@ public class MemoryRecordsTest {
 
         ByteBuffer filtered = ByteBuffer.allocate(2048);
         MemoryRecords records = MemoryRecords.readableRecords(buffer);
-        MemoryRecords.FilterResult filterResult = records.filterTo(new TopicPartition("foo", 0),
+        MemoryRecords.FilterResult filterResult = records.filterTo(
                 new MemoryRecords.RecordFilter(0, 0) {
                     @Override
                     protected BatchRetentionResult checkBatchRetention(RecordBatch batch) {
@@ -406,7 +404,7 @@ public class MemoryRecordsTest {
                     protected boolean shouldRetainRecord(RecordBatch recordBatch, Record record) {
                         return false;
                     }
-                }, filtered, Integer.MAX_VALUE, BufferSupplier.NO_CACHING);
+                }, filtered, BufferSupplier.NO_CACHING);
 
         // Verify filter result
         assertEquals(0, filterResult.messagesRead());
@@ -442,7 +440,7 @@ public class MemoryRecordsTest {
 
             ByteBuffer filtered = ByteBuffer.allocate(2048);
             MemoryRecords records = MemoryRecords.readableRecords(buffer);
-            MemoryRecords.FilterResult filterResult = records.filterTo(new TopicPartition("foo", 0),
+            MemoryRecords.FilterResult filterResult = records.filterTo(
                     new MemoryRecords.RecordFilter(0, 0) {
                         @Override
                         protected BatchRetentionResult checkBatchRetention(RecordBatch batch) {
@@ -453,7 +451,7 @@ public class MemoryRecordsTest {
                         protected boolean shouldRetainRecord(RecordBatch recordBatch, Record record) {
                             return false;
                         }
-                    }, filtered, Integer.MAX_VALUE, BufferSupplier.NO_CACHING);
+                    }, filtered, BufferSupplier.NO_CACHING);
 
             // Verify filter result
             assertEquals(0, filterResult.outputBuffer().position());
@@ -529,7 +527,7 @@ public class MemoryRecordsTest {
                return new BatchRetentionResult(BatchRetention.RETAIN_EMPTY, false);
             }
         };
-        builder.build().filterTo(new TopicPartition("random", 0), recordFilter, filtered, Integer.MAX_VALUE, BufferSupplier.NO_CACHING);
+        builder.build().filterTo(recordFilter, filtered, BufferSupplier.NO_CACHING);
         filtered.flip();
         MemoryRecords filteredRecords = MemoryRecords.readableRecords(filtered);
 
@@ -618,7 +616,7 @@ public class MemoryRecordsTest {
         buffer.flip();
 
         ByteBuffer filtered = ByteBuffer.allocate(2048);
-        MemoryRecords.readableRecords(buffer).filterTo(new TopicPartition("foo", 0), new MemoryRecords.RecordFilter(0, 0) {
+        MemoryRecords.readableRecords(buffer).filterTo(new MemoryRecords.RecordFilter(0, 0) {
             @Override
             protected BatchRetentionResult checkBatchRetention(RecordBatch batch) {
                 // discard the second and fourth batches
@@ -631,7 +629,7 @@ public class MemoryRecordsTest {
             protected boolean shouldRetainRecord(RecordBatch recordBatch, Record record) {
                 return true;
             }
-        }, filtered, Integer.MAX_VALUE, BufferSupplier.NO_CACHING);
+        }, filtered, BufferSupplier.NO_CACHING);
 
         filtered.flip();
         MemoryRecords filteredRecords = MemoryRecords.readableRecords(filtered);
@@ -667,8 +665,7 @@ public class MemoryRecordsTest {
         buffer.flip();
 
         ByteBuffer filtered = ByteBuffer.allocate(2048);
-        MemoryRecords.readableRecords(buffer).filterTo(new TopicPartition("foo", 0), new RetainNonNullKeysFilter(),
-                filtered, Integer.MAX_VALUE, BufferSupplier.NO_CACHING);
+        MemoryRecords.readableRecords(buffer).filterTo(new RetainNonNullKeysFilter(), filtered, BufferSupplier.NO_CACHING);
         filtered.flip();
         MemoryRecords filteredRecords = MemoryRecords.readableRecords(filtered);
 
@@ -743,8 +740,7 @@ public class MemoryRecordsTest {
             buffer.flip();
 
             ByteBuffer filtered = ByteBuffer.allocate(2048);
-            MemoryRecords.readableRecords(buffer).filterTo(new TopicPartition("foo", 0), new RetainNonNullKeysFilter(),
-                    filtered, Integer.MAX_VALUE, BufferSupplier.NO_CACHING);
+            MemoryRecords.readableRecords(buffer).filterTo(new RetainNonNullKeysFilter(), filtered, BufferSupplier.NO_CACHING);
 
             filtered.flip();
             MemoryRecords filteredRecords = MemoryRecords.readableRecords(filtered);
@@ -835,9 +831,8 @@ public class MemoryRecordsTest {
         while (buffer.hasRemaining()) {
             output.rewind();
 
-            MemoryRecords.FilterResult result = MemoryRecords.readableRecords(buffer)
-                    .filterTo(new TopicPartition("foo", 0), new RetainNonNullKeysFilter(), output, Integer.MAX_VALUE,
-                            BufferSupplier.NO_CACHING);
+            MemoryRecords.FilterResult result = MemoryRecords.readableRecords(buffer).filterTo(
+                    new RetainNonNullKeysFilter(), output, BufferSupplier.NO_CACHING);
 
             buffer.position(buffer.position() + result.bytesRead());
             result.outputBuffer().flip();
@@ -884,8 +879,7 @@ public class MemoryRecordsTest {
 
         ByteBuffer filtered = ByteBuffer.allocate(2048);
         MemoryRecords.FilterResult result = MemoryRecords.readableRecords(buffer).filterTo(
-                new TopicPartition("foo", 0), new RetainNonNullKeysFilter(), filtered, Integer.MAX_VALUE,
-                BufferSupplier.NO_CACHING);
+                new RetainNonNullKeysFilter(), filtered, BufferSupplier.NO_CACHING);
 
         filtered.flip();
 
@@ -928,14 +922,14 @@ public class MemoryRecordsTest {
             RecordBatch batch = batches.get(i);
             assertEquals(expectedStartOffsets.get(i).longValue(), batch.baseOffset());
             assertEquals(expectedEndOffsets.get(i).longValue(), batch.lastOffset());
-            assertEquals(magic, batch.magic());
+            assertEquals(RecordBatch.CURRENT_MAGIC_VALUE, batch.magic());
             assertEquals(compression.type(), batch.compressionType());
             if (magic >= RecordBatch.MAGIC_VALUE_V1) {
                 assertEquals(expectedMaxTimestamps.get(i).longValue(), batch.maxTimestamp());
                 assertEquals(TimestampType.CREATE_TIME, batch.timestampType());
             } else {
                 assertEquals(RecordBatch.NO_TIMESTAMP, batch.maxTimestamp());
-                assertEquals(TimestampType.NO_TIMESTAMP_TYPE, batch.timestampType());
+                assertEquals(TimestampType.CREATE_TIME, batch.timestampType());
             }
         }
 
@@ -1003,8 +997,7 @@ public class MemoryRecordsTest {
         buffer.flip();
 
         ByteBuffer filtered = ByteBuffer.allocate(2048);
-        MemoryRecords.readableRecords(buffer).filterTo(new TopicPartition("foo", 0), new RetainNonNullKeysFilter(),
-                filtered, Integer.MAX_VALUE, BufferSupplier.NO_CACHING);
+        MemoryRecords.readableRecords(buffer).filterTo(new RetainNonNullKeysFilter(), filtered, BufferSupplier.NO_CACHING);
 
         filtered.flip();
         MemoryRecords filteredRecords = MemoryRecords.readableRecords(filtered);
diff --git a/core/src/main/java/kafka/log/remote/RemoteLogManager.java b/core/src/main/java/kafka/log/remote/RemoteLogManager.java
index bcdd718baa1..b5f9e408c94 100644
--- a/core/src/main/java/kafka/log/remote/RemoteLogManager.java
+++ b/core/src/main/java/kafka/log/remote/RemoteLogManager.java
@@ -777,11 +777,7 @@ public class RemoteLogManager implements Closeable {
      * @return the leader epoch entries
      */
    List<EpochEntry> getLeaderEpochEntries(UnifiedLog log, long startOffset, long endOffset) {
-        if (log.leaderEpochCache().isDefined()) {
-            return log.leaderEpochCache().get().epochEntriesInRange(startOffset, endOffset);
-        } else {
-            return Collections.emptyList();
-        }
+        return log.leaderEpochCache().epochEntriesInRange(startOffset, endOffset);
     }
 
     // VisibleForTesting
@@ -1249,11 +1245,6 @@ public class RemoteLogManager implements Closeable {
             }
 
             final UnifiedLog log = logOptional.get();
-            final Option<LeaderEpochFileCache> leaderEpochCacheOption = log.leaderEpochCache();
-            if (leaderEpochCacheOption.isEmpty()) {
-                logger.debug("No leader epoch cache available for partition: {}", topicIdPartition);
-                return;
-            }
 
             // Cleanup remote log segments and update the log start offset if applicable.
             final Iterator<RemoteLogSegmentMetadata> segmentMetadataIter = remoteLogMetadataManager.listRemoteLogSegments(topicIdPartition);
@@ -1281,7 +1272,7 @@ public class RemoteLogManager implements Closeable {
             final List<Integer> remoteLeaderEpochs = new ArrayList<>(epochsSet);
             Collections.sort(remoteLeaderEpochs);
 
-            LeaderEpochFileCache leaderEpochCache = leaderEpochCacheOption.get();
+            LeaderEpochFileCache leaderEpochCache = log.leaderEpochCache();
             // Build the leader epoch map by filtering the epochs that do not have any records.
             NavigableMap<Integer, Long> epochWithOffsets = buildFilteredLeaderEpochMap(leaderEpochCache.epochWithOffsets());
 
@@ -1680,10 +1671,8 @@ public class RemoteLogManager implements Closeable {
         OptionalInt epoch = OptionalInt.empty();
 
         if (logOptional.isPresent()) {
-            Option<LeaderEpochFileCache> leaderEpochCache = logOptional.get().leaderEpochCache();
-            if (leaderEpochCache != null && leaderEpochCache.isDefined()) {
-                epoch = leaderEpochCache.get().epochForOffset(offset);
-            }
+            LeaderEpochFileCache leaderEpochCache = logOptional.get().leaderEpochCache();
+            epoch = leaderEpochCache.epochForOffset(offset);
         }
 
         Optional<RemoteLogSegmentMetadata> rlsMetadataOptional = epoch.isPresent()
@@ -1819,7 +1808,7 @@ public class RemoteLogManager implements Closeable {
                                            UnifiedLog log) throws RemoteStorageException {
         TopicPartition tp = segmentMetadata.topicIdPartition().topicPartition();
         boolean isSearchComplete = false;
-        LeaderEpochFileCache leaderEpochCache = log.leaderEpochCache().getOrElse(null);
+        LeaderEpochFileCache leaderEpochCache = log.leaderEpochCache();
         Optional<RemoteLogSegmentMetadata> currentMetadataOpt = Optional.of(segmentMetadata);
         while (!isSearchComplete && currentMetadataOpt.isPresent()) {
             RemoteLogSegmentMetadata currentMetadata = currentMetadataOpt.get();
@@ -1866,13 +1855,9 @@ public class RemoteLogManager implements Closeable {
 
     // visible for testing.
    Optional<RemoteLogSegmentMetadata> findNextSegmentMetadata(RemoteLogSegmentMetadata segmentMetadata,
-                                                               Option<LeaderEpochFileCache> leaderEpochFileCacheOption) throws RemoteStorageException {
-        if (leaderEpochFileCacheOption.isEmpty()) {
-            return Optional.empty();
-        }
-
+                                                               LeaderEpochFileCache leaderEpochFileCacheOption) throws RemoteStorageException {
         long nextSegmentBaseOffset = segmentMetadata.endOffset() + 1;
-        OptionalInt epoch = leaderEpochFileCacheOption.get().epochForOffset(nextSegmentBaseOffset);
+        OptionalInt epoch = leaderEpochFileCacheOption.epochForOffset(nextSegmentBaseOffset);
         return epoch.isPresent()
                 ? fetchRemoteLogSegmentMetadata(segmentMetadata.topicIdPartition().topicPartition(), epoch.getAsInt(), nextSegmentBaseOffset)
                 : Optional.empty();
@@ -1887,7 +1872,7 @@ public class RemoteLogManager implements Closeable {
      * Visible for testing
      * @param tp The topic partition.
      * @param offset The offset to start the search.
-     * @param leaderEpochCache The leader epoch file cache, this could be null.
+     * @param leaderEpochCache The leader epoch file cache.
     * @return The next segment metadata that contains the transaction index. The transaction index may or may not exist
      * in that segment metadata which depends on the RLMM plugin implementation. The caller of this method should handle
      * for both the cases.
@@ -1896,9 +1881,6 @@ public class RemoteLogManager implements Closeable {
     Optional<RemoteLogSegmentMetadata> findNextSegmentWithTxnIndex(TopicPartition tp,
                                                                    long offset,
                                                                    LeaderEpochFileCache leaderEpochCache) throws RemoteStorageException {
-        if (leaderEpochCache == null) {
-            return Optional.empty();
-        }
         OptionalInt initialEpochOpt = leaderEpochCache.epochForOffset(offset);
         if (initialEpochOpt.isEmpty()) {
             return Optional.empty();
@@ -1933,30 +1915,27 @@ public class RemoteLogManager implements Closeable {
 
    OffsetAndEpoch findHighestRemoteOffset(TopicIdPartition topicIdPartition, UnifiedLog log) throws RemoteStorageException {
        OffsetAndEpoch offsetAndEpoch = null;
-        Option<LeaderEpochFileCache> leaderEpochCacheOpt = log.leaderEpochCache();
-        if (leaderEpochCacheOpt.isDefined()) {
-            LeaderEpochFileCache cache = leaderEpochCacheOpt.get();
-            Optional<EpochEntry> maybeEpochEntry = cache.latestEntry();
-            while (offsetAndEpoch == null && maybeEpochEntry.isPresent()) {
-                int epoch = maybeEpochEntry.get().epoch;
-                Optional<Long> highestRemoteOffsetOpt =
-                        remoteLogMetadataManager.highestOffsetForEpoch(topicIdPartition, epoch);
-                if (highestRemoteOffsetOpt.isPresent()) {
-                    Map.Entry<Integer, Long> entry = cache.endOffsetFor(epoch, log.logEndOffset());
-                    int requestedEpoch = entry.getKey();
-                    long endOffset = entry.getValue();
-                    long highestRemoteOffset = highestRemoteOffsetOpt.get();
-                    if (endOffset <= highestRemoteOffset) {
-                        LOGGER.info("The end-offset for epoch {}: ({}, {}) is less than or equal to the " +
-                                "highest-remote-offset: {} for partition: {}", epoch, requestedEpoch, endOffset,
-                                highestRemoteOffset, topicIdPartition);
-                        offsetAndEpoch = new OffsetAndEpoch(endOffset - 1, requestedEpoch);
-                    } else {
-                        offsetAndEpoch = new OffsetAndEpoch(highestRemoteOffset, epoch);
-                    }
+        LeaderEpochFileCache leaderEpochCache = log.leaderEpochCache();
+        Optional<EpochEntry> maybeEpochEntry = leaderEpochCache.latestEntry();
+        while (offsetAndEpoch == null && maybeEpochEntry.isPresent()) {
+            int epoch = maybeEpochEntry.get().epoch;
+            Optional<Long> highestRemoteOffsetOpt =
+                    remoteLogMetadataManager.highestOffsetForEpoch(topicIdPartition, epoch);
+            if (highestRemoteOffsetOpt.isPresent()) {
+                Map.Entry<Integer, Long> entry = leaderEpochCache.endOffsetFor(epoch, log.logEndOffset());
+                int requestedEpoch = entry.getKey();
+                long endOffset = entry.getValue();
+                long highestRemoteOffset = highestRemoteOffsetOpt.get();
+                if (endOffset <= highestRemoteOffset) {
+                    LOGGER.info("The end-offset for epoch {}: ({}, {}) is less than or equal to the " +
+                            "highest-remote-offset: {} for partition: {}", epoch, requestedEpoch, endOffset,
+                            highestRemoteOffset, topicIdPartition);
+                    offsetAndEpoch = new OffsetAndEpoch(endOffset - 1, requestedEpoch);
+                } else {
+                    offsetAndEpoch = new OffsetAndEpoch(highestRemoteOffset, epoch);
                }
-                maybeEpochEntry = cache.previousEntry(epoch);
            }
+            maybeEpochEntry = leaderEpochCache.previousEntry(epoch);
        }
        if (offsetAndEpoch == null) {
            offsetAndEpoch = new OffsetAndEpoch(-1L, RecordBatch.NO_PARTITION_LEADER_EPOCH);
@@ -1966,20 +1945,17 @@ public class RemoteLogManager implements Closeable {
 
    long findLogStartOffset(TopicIdPartition topicIdPartition, UnifiedLog log) throws RemoteStorageException {
        Optional<Long> logStartOffset = Optional.empty();
-        Option<LeaderEpochFileCache> maybeLeaderEpochFileCache = log.leaderEpochCache();
-        if (maybeLeaderEpochFileCache.isDefined()) {
-            LeaderEpochFileCache cache = maybeLeaderEpochFileCache.get();
-            OptionalInt earliestEpochOpt = cache.earliestEntry()
-                    .map(epochEntry -> OptionalInt.of(epochEntry.epoch))
-                    .orElseGet(OptionalInt::empty);
-            while (logStartOffset.isEmpty() && earliestEpochOpt.isPresent()) {
-                Iterator<RemoteLogSegmentMetadata> iterator =
-                        remoteLogMetadataManager.listRemoteLogSegments(topicIdPartition, earliestEpochOpt.getAsInt());
-                if (iterator.hasNext()) {
-                    logStartOffset = Optional.of(iterator.next().startOffset());
-                }
-                earliestEpochOpt = cache.nextEpoch(earliestEpochOpt.getAsInt());
+        LeaderEpochFileCache leaderEpochCache = log.leaderEpochCache();
+        OptionalInt earliestEpochOpt = leaderEpochCache.earliestEntry()
+                .map(epochEntry -> OptionalInt.of(epochEntry.epoch))
+                .orElseGet(OptionalInt::empty);
+        while (logStartOffset.isEmpty() && earliestEpochOpt.isPresent()) {
+            Iterator<RemoteLogSegmentMetadata> iterator =
+                    remoteLogMetadataManager.listRemoteLogSegments(topicIdPartition, earliestEpochOpt.getAsInt());
+            if (iterator.hasNext()) {
+                logStartOffset = Optional.of(iterator.next().startOffset());
            }
+            earliestEpochOpt = leaderEpochCache.nextEpoch(earliestEpochOpt.getAsInt());
         }
         return logStartOffset.orElseGet(log::localLogStartOffset);
     }
diff --git a/core/src/main/java/kafka/server/TierStateMachine.java b/core/src/main/java/kafka/server/TierStateMachine.java
index ddb19e86aec..d316e70da2e 100644
--- a/core/src/main/java/kafka/server/TierStateMachine.java
+++ b/core/src/main/java/kafka/server/TierStateMachine.java
@@ -247,9 +247,7 @@ public class TierStateMachine {
 
         // Build leader epoch cache.
         List<EpochEntry> epochs = readLeaderEpochCheckpoint(rlm, remoteLogSegmentMetadata);
-        if (unifiedLog.leaderEpochCache().isDefined()) {
-            unifiedLog.leaderEpochCache().get().assign(epochs);
-        }
+        unifiedLog.leaderEpochCache().assign(epochs);
 
         log.info("Updated the epoch cache from remote tier till offset: {} with size: {} for {}", leaderLocalLogStartOffset, epochs.size(), partition);
 
diff --git a/core/src/main/scala/kafka/cluster/Partition.scala b/core/src/main/scala/kafka/cluster/Partition.scala
index e58f5824e76..e4d2f01cdd6 100755
--- a/core/src/main/scala/kafka/cluster/Partition.scala
+++ b/core/src/main/scala/kafka/cluster/Partition.scala
@@ -806,7 +806,7 @@ class Partition(val topicPartition: TopicPartition,
         // to ensure that these followers can truncate to the right offset, we must cache the new
         // leader epoch and the start offset since it should be larger than any epoch that a follower
         // would try to query.
-        leaderLog.maybeAssignEpochStartOffset(partitionState.leaderEpoch, leaderEpochStartOffset)
+        leaderLog.assignEpochStartOffset(partitionState.leaderEpoch, leaderEpochStartOffset)
 
         // Initialize lastCaughtUpTime of replicas as well as their lastFetchTimeMs and
         // lastFetchLeaderLogEndOffset.
diff --git a/core/src/main/scala/kafka/log/LogCleaner.scala b/core/src/main/scala/kafka/log/LogCleaner.scala
index 4f8d545be60..43193016fd0 100644
--- a/core/src/main/scala/kafka/log/LogCleaner.scala
+++ b/core/src/main/scala/kafka/log/LogCleaner.scala
@@ -684,7 +684,8 @@ private[log] class Cleaner(val id: Int,
 
         try {
-            cleanInto(log.topicPartition, currentSegment.log, cleaned, map, retainLegacyDeletesAndTxnMarkers, log.config.deleteRetentionMs,
-            log.config.maxMessageSize, transactionMetadata, lastOffsetOfActiveProducers, upperBoundOffsetOfCleaningRound, stats, currentTime = currentTime)
+            cleanInto(log.topicPartition, currentSegment.log, cleaned, map, retainLegacyDeletesAndTxnMarkers, log.config.deleteRetentionMs,
+            log.config.maxMessageSize, transactionMetadata, lastOffsetOfActiveProducers,
+            upperBoundOffsetOfCleaningRound, stats, currentTime = currentTime)
         } catch {
           case e: LogSegmentOffsetOverflowException =>
             // Split the current segment. It's also safest to abort the current cleaning process, so that we retry from
@@ -810,7 +811,7 @@ private[log] class Cleaner(val id: Int,
       sourceRecords.readInto(readBuffer, position)
       val records = MemoryRecords.readableRecords(readBuffer)
       throttler.maybeThrottle(records.sizeInBytes)
-      val result = records.filterTo(topicPartition, logCleanerFilter, writeBuffer, maxLogMessageSize, decompressionBufferSupplier)
+      val result = records.filterTo(logCleanerFilter, writeBuffer, decompressionBufferSupplier)
 
       stats.readMessages(result.messagesRead, result.bytesRead)
       stats.recopyMessages(result.messagesRetained, result.bytesRetained)
diff --git a/core/src/main/scala/kafka/log/UnifiedLog.scala b/core/src/main/scala/kafka/log/UnifiedLog.scala
index 3f129deec7f..b3d6588de06 100644
--- a/core/src/main/scala/kafka/log/UnifiedLog.scala
+++ b/core/src/main/scala/kafka/log/UnifiedLog.scala
@@ -99,7 +99,7 @@ class UnifiedLog(@volatile var logStartOffset: Long,
                  private val localLog: LocalLog,
                  val brokerTopicStats: BrokerTopicStats,
                  val producerIdExpirationCheckIntervalMs: Int,
-                 @volatile var leaderEpochCache: Option[LeaderEpochFileCache],
+                 @volatile var leaderEpochCache: LeaderEpochFileCache,
                  val producerStateManager: ProducerStateManager,
                  @volatile private var _topicId: Option[Uuid],
                  val keepPartitionMetadataFile: Boolean,
@@ -508,9 +508,9 @@ class UnifiedLog(@volatile var logStartOffset: Long,
     }
   }
 
-  private def initializeLeaderEpochCache(): Unit = lock synchronized {
-    leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(
-      dir, topicPartition, logDirFailureChannel, logIdent, leaderEpochCache, scheduler)
+  private def reinitializeLeaderEpochCache(): Unit = lock synchronized {
+    leaderEpochCache = UnifiedLog.createLeaderEpochCache(
+      dir, topicPartition, logDirFailureChannel, Option.apply(leaderEpochCache), scheduler)
   }
 
   private def updateHighWatermarkWithLogEndOffset(): Unit = {
@@ -672,10 +672,10 @@ class UnifiedLog(@volatile var logStartOffset: Long,
           if (shouldReinitialize) {
            // re-initialize leader epoch cache so that LeaderEpochCheckpointFile.checkpoint can correctly reference
             // the checkpoint file in renamed log directory
-            initializeLeaderEpochCache()
+            reinitializeLeaderEpochCache()
             initializePartitionMetadata()
           } else {
-            leaderEpochCache = None
+            leaderEpochCache.clear()
             partitionMetadataFile = None
           }
         }
@@ -713,6 +713,18 @@ class UnifiedLog(@volatile var logStartOffset: Long,
    append(records, origin, interBrokerProtocolVersion, validateAndAssignOffsets, leaderEpoch, Some(requestLocal), verificationGuard, ignoreRecordSize = false)
   }
 
+  /**
+   * Even though we always write to disk with record version v2 since Apache Kafka 4.0, older record versions may have
+   * been persisted to disk before that. In order to test such scenarios, we need the ability to append with older
+   * record versions. This method exists for that purpose and hence it should only be used from test code.
+   *
+   * Also see #appendAsLeader.
+   */
+  private[log] def appendAsLeaderWithRecordVersion(records: MemoryRecords, leaderEpoch: Int, recordVersion: RecordVersion): LogAppendInfo = {
+    append(records, AppendOrigin.CLIENT, MetadataVersion.latestProduction, true, leaderEpoch, Some(RequestLocal.noCaching),
+      VerificationGuard.SENTINEL, ignoreRecordSize = false, recordVersion.value)
+  }
+
   /**
    * Append this message set to the active segment of the local log without assigning offsets or Partition Leader Epochs
    *
@@ -757,7 +769,8 @@ class UnifiedLog(@volatile var logStartOffset: Long,
                      leaderEpoch: Int,
                      requestLocal: Option[RequestLocal],
                      verificationGuard: VerificationGuard,
-                     ignoreRecordSize: Boolean): LogAppendInfo = {
+                     ignoreRecordSize: Boolean,
+                     toMagic: Byte = RecordBatch.CURRENT_MAGIC_VALUE): LogAppendInfo = {
     // We want to ensure the partition metadata file is written to the log dir before any log data is written to disk.
     // This will ensure that any log data can be recovered with the correct topic ID in the case of failure.
     maybeFlushMetadataFile()
@@ -787,7 +800,7 @@ class UnifiedLog(@volatile var logStartOffset: Long,
                 appendInfo.sourceCompression,
                 targetCompression,
                 config.compact,
-                RecordBatch.CURRENT_MAGIC_VALUE,
+                toMagic,
                 config.messageTimestampType,
                 config.messageTimestampBeforeMaxMs,
                 config.messageTimestampAfterMaxMs,
@@ -850,14 +863,14 @@ class UnifiedLog(@volatile var logStartOffset: Long,
          // update the epoch cache with the epoch stamped onto the message by the leader
           validRecords.batches.forEach { batch =>
             if (batch.magic >= RecordBatch.MAGIC_VALUE_V2) {
-              maybeAssignEpochStartOffset(batch.partitionLeaderEpoch, batch.baseOffset)
+              assignEpochStartOffset(batch.partitionLeaderEpoch, batch.baseOffset)
             } else {
               // In partial upgrade scenarios, we may get a temporary regression to the message format. In
               // order to ensure the safety of leader election, we clear the epoch cache so that we revert
               // to truncation by high watermark after the next leader election.
-              leaderEpochCache.filter(_.nonEmpty).foreach { cache =>
+              if (leaderEpochCache.nonEmpty) {
                 warn(s"Clearing leader epoch cache after unexpected append with message format v${batch.magic}")
-                cache.clearAndFlush()
+                leaderEpochCache.clearAndFlush()
               }
             }
           }
@@ -928,23 +941,18 @@ class UnifiedLog(@volatile var logStartOffset: Long,
     }
   }
 
-  def maybeAssignEpochStartOffset(leaderEpoch: Int, startOffset: Long): Unit = {
-    leaderEpochCache.foreach { cache =>
-      cache.assign(leaderEpoch, startOffset)
-    }
-  }
+  def assignEpochStartOffset(leaderEpoch: Int, startOffset: Long): Unit =
+    leaderEpochCache.assign(leaderEpoch, startOffset)
 
-  def latestEpoch: Option[Int] = leaderEpochCache.flatMap(_.latestEpoch.toScala)
+  def latestEpoch: Option[Int] = leaderEpochCache.latestEpoch.toScala
 
   def endOffsetForEpoch(leaderEpoch: Int): Option[OffsetAndEpoch] = {
-    leaderEpochCache.flatMap { cache =>
-      val entry = cache.endOffsetFor(leaderEpoch, logEndOffset)
-      val (foundEpoch, foundOffset) = (entry.getKey, entry.getValue)
-      if (foundOffset == UNDEFINED_EPOCH_OFFSET)
-        None
-      else
-        Some(new OffsetAndEpoch(foundOffset, foundEpoch))
-    }
+    val entry = leaderEpochCache.endOffsetFor(leaderEpoch, logEndOffset)
+    val (foundEpoch, foundOffset) = (entry.getKey, entry.getValue)
+    if (foundOffset == UNDEFINED_EPOCH_OFFSET)
+      None
+    else
+      Some(new OffsetAndEpoch(foundOffset, foundEpoch))
   }
 
   private def maybeIncrementFirstUnstableOffset(): Unit = lock synchronized {
@@ -1004,7 +1012,7 @@ class UnifiedLog(@volatile var logStartOffset: Long,
           updatedLogStartOffset = true
           updateLogStartOffset(newLogStartOffset)
           info(s"Incremented log start offset to $newLogStartOffset due to 
$reason")
-          leaderEpochCache.foreach(_.truncateFromStartAsyncFlush(logStartOffset))
+          leaderEpochCache.truncateFromStartAsyncFlush(logStartOffset)
           producerStateManager.onLogStartOffsetIncremented(newLogStartOffset)
           maybeIncrementFirstUnstableOffset()
         }
@@ -1271,7 +1279,7 @@ class UnifiedLog(@volatile var logStartOffset: Long,
        // The first cached epoch usually corresponds to the log start offset, but we have to verify this since
        // it may not be true following a message format version bump as the epoch will not be available for
        // log entries written in the older format.
-        val earliestEpochEntry = leaderEpochCache.toJava.flatMap(_.earliestEntry())
+        val earliestEpochEntry = leaderEpochCache.earliestEntry()
        val epochOpt = if (earliestEpochEntry.isPresent && earliestEpochEntry.get().startOffset <= logStartOffset) {
           Optional.of[Integer](earliestEpochEntry.get().epoch)
         } else Optional.empty[Integer]()
@@ -1280,41 +1288,24 @@ class UnifiedLog(@volatile var logStartOffset: Long,
      } else if (targetTimestamp == ListOffsetsRequest.EARLIEST_LOCAL_TIMESTAMP) {
        val curLocalLogStartOffset = localLogStartOffset()

-        val epochResult: Optional[Integer] =
-          if (leaderEpochCache.isDefined) {
-            val epochOpt = leaderEpochCache.get.epochForOffset(curLocalLogStartOffset)
-            if (epochOpt.isPresent) Optional.of(epochOpt.getAsInt) else Optional.empty()
-          } else {
-            Optional.empty()
-          }
+        val epochResult: Optional[Integer] = {
+          val epochOpt = leaderEpochCache.epochForOffset(curLocalLogStartOffset)
+          if (epochOpt.isPresent) Optional.of(epochOpt.getAsInt) else Optional.empty()
+        }

        new OffsetResultHolder(new TimestampAndOffset(RecordBatch.NO_TIMESTAMP, curLocalLogStartOffset, epochResult))
       } else if (targetTimestamp == ListOffsetsRequest.LATEST_TIMESTAMP) {
-        val epoch = leaderEpochCache match {
-          case Some(cache) =>
-            val latestEpoch = cache.latestEpoch()
-            if (latestEpoch.isPresent) Optional.of[Integer](latestEpoch.getAsInt) else Optional.empty[Integer]()
-          case None => Optional.empty[Integer]()
-        }
+        val latestEpoch = leaderEpochCache.latestEpoch()
+        val epoch = if (latestEpoch.isPresent) Optional.of[Integer](latestEpoch.getAsInt) else Optional.empty[Integer]()
        new OffsetResultHolder(new TimestampAndOffset(RecordBatch.NO_TIMESTAMP, logEndOffset, epoch))
      } else if (targetTimestamp == ListOffsetsRequest.LATEST_TIERED_TIMESTAMP) {
        if (remoteLogEnabled()) {
          val curHighestRemoteOffset = highestOffsetInRemoteStorage()
-
+          val epochOpt = leaderEpochCache.epochForOffset(curHighestRemoteOffset)
          val epochResult: Optional[Integer] =
-            if (leaderEpochCache.isDefined) {
-              val epochOpt = leaderEpochCache.get.epochForOffset(curHighestRemoteOffset)
-              if (epochOpt.isPresent) {
-                Optional.of(epochOpt.getAsInt)
-              } else if (curHighestRemoteOffset == -1) {
-                Optional.of(RecordBatch.NO_PARTITION_LEADER_EPOCH)
-              } else {
-                Optional.empty()
-              }
-            } else {
-              Optional.empty()
-            }
-
+            if (epochOpt.isPresent) Optional.of(epochOpt.getAsInt)
+            else if (curHighestRemoteOffset == -1) Optional.of(RecordBatch.NO_PARTITION_LEADER_EPOCH)
+            else Optional.empty()
          new OffsetResultHolder(new TimestampAndOffset(RecordBatch.NO_TIMESTAMP, curHighestRemoteOffset, epochResult))
        } else {
          new OffsetResultHolder(new TimestampAndOffset(RecordBatch.NO_TIMESTAMP, -1L, Optional.of(-1)))
@@ -1340,7 +1331,7 @@ class UnifiedLog(@volatile var logStartOffset: Long,
           }
 
           val asyncOffsetReadFutureHolder = remoteLogManager.get.asyncOffsetRead(topicPartition, targetTimestamp,
-            logStartOffset, leaderEpochCache.get, () => searchOffsetInLocalLog(targetTimestamp, localLogStartOffset()))
+            logStartOffset, leaderEpochCache, () => searchOffsetInLocalLog(targetTimestamp, localLogStartOffset()))
 
           new OffsetResultHolder(Optional.empty(), Optional.of(asyncOffsetReadFutureHolder))
         } else {
@@ -1768,7 +1759,7 @@ class UnifiedLog(@volatile var logStartOffset: Long,
       lock synchronized {
         localLog.checkIfMemoryMappedBufferClosed()
         producerExpireCheck.cancel(true)
-        leaderEpochCache.foreach(_.clear())
+        leaderEpochCache.clear()
         val deletedSegments = localLog.deleteAllSegments()
         deleteProducerSnapshots(deletedSegments, asyncDelete = false)
         localLog.deleteEmptyDir()
@@ -1821,7 +1812,7 @@ class UnifiedLog(@volatile var logStartOffset: Long,
         // and inserted the first start offset entry, but then failed to append any entries
         // before another leader was elected.
         lock synchronized {
-          leaderEpochCache.foreach(_.truncateFromEndAsyncFlush(logEndOffset))
+          leaderEpochCache.truncateFromEndAsyncFlush(logEndOffset)
         }
 
         false
@@ -1834,7 +1825,7 @@ class UnifiedLog(@volatile var logStartOffset: Long,
           } else {
             val deletedSegments = localLog.truncateTo(targetOffset)
             deleteProducerSnapshots(deletedSegments, asyncDelete = true)
-            leaderEpochCache.foreach(_.truncateFromEndAsyncFlush(targetOffset))
+            leaderEpochCache.truncateFromEndAsyncFlush(targetOffset)
             logStartOffset = math.min(targetOffset, logStartOffset)
             rebuildProducerState(targetOffset, producerStateManager)
             if (highWatermark >= localLog.logEndOffset)
@@ -1858,7 +1849,7 @@ class UnifiedLog(@volatile var logStartOffset: Long,
       debug(s"Truncate and start at offset $newOffset, logStartOffset: 
${logStartOffsetOpt.getOrElse(newOffset)}")
       lock synchronized {
         localLog.truncateFullyAndStartAt(newOffset)
-        leaderEpochCache.foreach(_.clearAndFlush())
+        leaderEpochCache.clearAndFlush()
         producerStateManager.truncateFullyAndStartAt(newOffset)
         logStartOffset = logStartOffsetOpt.getOrElse(newOffset)
         if (remoteLogEnabled()) _localLogStartOffset = newOffset
@@ -2015,11 +2006,10 @@ object UnifiedLog extends Logging {
     // The created leaderEpochCache will be truncated by LogLoader if necessary
     // so it is guaranteed that the epoch entries will be correct even when on-disk
     // checkpoint was stale (due to async nature of LeaderEpochFileCache#truncateFromStart/End).
-    val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(
+    val leaderEpochCache = UnifiedLog.createLeaderEpochCache(
       dir,
       topicPartition,
       logDirFailureChannel,
-      s"[UnifiedLog partition=$topicPartition, dir=${dir.getParent}] ",
       None,
       scheduler)
     val producerStateManager = new ProducerStateManager(topicPartition, dir,
@@ -2036,7 +2026,7 @@ object UnifiedLog extends Logging {
       segments,
       logStartOffset,
       recoveryPoint,
-      leaderEpochCache.toJava,
+      leaderEpochCache,
       producerStateManager,
       numRemainingSegments,
       isRemoteLogEnabled,
@@ -2072,29 +2062,24 @@ object UnifiedLog extends Logging {
   def parseTopicPartitionName(dir: File): TopicPartition = LocalLog.parseTopicPartitionName(dir)
 
   /**
-   * If the recordVersion is >= RecordVersion.V2, create a new LeaderEpochFileCache instance.
-   * Loading the epoch entries from the backing checkpoint file or the provided currentCache if not empty.
-   * Otherwise, the message format is considered incompatible and the existing LeaderEpoch file
-   * is deleted.
+   * Create a new LeaderEpochFileCache instance and load the epoch entries from the backing checkpoint file or
+   * the provided currentCache (if not empty).
    *
    * @param dir                  The directory in which the log will reside
    * @param topicPartition       The topic partition
    * @param logDirFailureChannel The LogDirFailureChannel to asynchronously handle log dir failure
-   * @param logPrefix            The logging prefix
    * @param currentCache         The current LeaderEpochFileCache instance (if any)
    * @param scheduler            The scheduler for executing asynchronous tasks
-   * @return The new LeaderEpochFileCache instance (if created), none otherwise
+   * @return The new LeaderEpochFileCache instance
    */
-  def maybeCreateLeaderEpochCache(dir: File,
-                                  topicPartition: TopicPartition,
-                                  logDirFailureChannel: LogDirFailureChannel,
-                                  logPrefix: String,
-                                  currentCache: Option[LeaderEpochFileCache],
-                                  scheduler: Scheduler): Option[LeaderEpochFileCache] = {
+  def createLeaderEpochCache(dir: File,
+                             topicPartition: TopicPartition,
+                             logDirFailureChannel: LogDirFailureChannel,
+                             currentCache: Option[LeaderEpochFileCache],
+                             scheduler: Scheduler): LeaderEpochFileCache = {
     val leaderEpochFile = LeaderEpochCheckpointFile.newFile(dir)
     val checkpointFile = new LeaderEpochCheckpointFile(leaderEpochFile, logDirFailureChannel)
-    currentCache.map(_.withCheckpoint(checkpointFile))
-      .orElse(Some(new LeaderEpochFileCache(topicPartition, checkpointFile, scheduler)))
+    currentCache.map(_.withCheckpoint(checkpointFile)).getOrElse(new LeaderEpochFileCache(topicPartition, checkpointFile, scheduler))
   }
 
   private[log] def replaceSegments(existingSegments: LogSegments,
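
For readers tracking the API change: the cache is now unconditional, so call sites collapse to one shape. A minimal sketch in Scala (not part of the patch; `logDir`, `tp` and `scheduler` are assumed to be in scope):

    // Sketch: consuming the always-present leader epoch cache.
    val cache: LeaderEpochFileCache = UnifiedLog.createLeaderEpochCache(
      logDir, tp, new LogDirFailureChannel(1),
      None,       // no existing cache, so entries load from the checkpoint file
      scheduler)  // runs the async truncateFromStart/End checkpoint flushes
    // OptionalInt results map directly to response fields, no Option unwrapping:
    val latest = cache.latestEpoch()
    val epoch: Optional[Integer] =
      if (latest.isPresent) Optional.of(latest.getAsInt) else Optional.empty()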
diff --git a/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala b/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala
index 919f4992d33..d3ab1f25ff3 100644
--- a/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala
+++ b/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala
@@ -197,7 +197,7 @@ final class KafkaMetadataLog private (
   }
 
   override def initializeLeaderEpoch(epoch: Int): Unit = {
-    log.maybeAssignEpochStartOffset(epoch, log.logEndOffset)
+    log.assignEpochStartOffset(epoch, log.logEndOffset)
   }
 
   override def updateHighWatermark(offsetMetadata: LogOffsetMetadata): Unit = {
diff --git a/core/src/main/scala/kafka/server/LocalLeaderEndPoint.scala b/core/src/main/scala/kafka/server/LocalLeaderEndPoint.scala
index 03258295a41..1e2a6cd033e 100644
--- a/core/src/main/scala/kafka/server/LocalLeaderEndPoint.scala
+++ b/core/src/main/scala/kafka/server/LocalLeaderEndPoint.scala
@@ -118,21 +118,21 @@ class LocalLeaderEndPoint(sourceBroker: BrokerEndPoint,
   override def fetchEarliestOffset(topicPartition: TopicPartition, currentLeaderEpoch: Int): OffsetAndEpoch = {
     val partition = replicaManager.getPartitionOrException(topicPartition)
     val logStartOffset = partition.localLogOrException.logStartOffset
-    val epoch = partition.localLogOrException.leaderEpochCache.get.epochForOffset(logStartOffset)
+    val epoch = partition.localLogOrException.leaderEpochCache.epochForOffset(logStartOffset)
     new OffsetAndEpoch(logStartOffset, epoch.orElse(0))
   }
 
   override def fetchLatestOffset(topicPartition: TopicPartition, currentLeaderEpoch: Int): OffsetAndEpoch = {
     val partition = replicaManager.getPartitionOrException(topicPartition)
     val logEndOffset = partition.localLogOrException.logEndOffset
-    val epoch = partition.localLogOrException.leaderEpochCache.get.epochForOffset(logEndOffset)
+    val epoch = partition.localLogOrException.leaderEpochCache.epochForOffset(logEndOffset)
     new OffsetAndEpoch(logEndOffset, epoch.orElse(0))
   }
 
   override def fetchEarliestLocalOffset(topicPartition: TopicPartition, currentLeaderEpoch: Int): OffsetAndEpoch = {
     val partition = replicaManager.getPartitionOrException(topicPartition)
     val localLogStartOffset = partition.localLogOrException.localLogStartOffset()
-    val epoch = partition.localLogOrException.leaderEpochCache.get.epochForOffset(localLogStartOffset)
+    val epoch = partition.localLogOrException.leaderEpochCache.epochForOffset(localLogStartOffset)
     new OffsetAndEpoch(localLogStartOffset, epoch.orElse(0))
   }
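
All three overrides above now share a single lookup shape. As a hedged sketch (assuming a `UnifiedLog` handle named `log`):

    // Sketch: epoch-for-offset lookup with the default of 0 used above.
    def offsetAndEpochAt(log: UnifiedLog, offset: Long): OffsetAndEpoch = {
      val epoch = log.leaderEpochCache.epochForOffset(offset) // OptionalInt
      new OffsetAndEpoch(offset, epoch.orElse(0))             // epoch 0 when none is recorded
    }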
 
diff --git a/core/src/test/java/kafka/log/remote/RemoteLogManagerTest.java b/core/src/test/java/kafka/log/remote/RemoteLogManagerTest.java
index 2ae2a184670..4e8a3206352 100644
--- a/core/src/test/java/kafka/log/remote/RemoteLogManagerTest.java
+++ b/core/src/test/java/kafka/log/remote/RemoteLogManagerTest.java
@@ -279,7 +279,7 @@ public class RemoteLogManagerTest {
     void testGetLeaderEpochCheckpoint() {
         checkpoint.write(totalEpochEntries);
         LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint, scheduler);
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
         assertEquals(totalEpochEntries, remoteLogManager.getLeaderEpochEntries(mockLog, 0, 300));
 
         List<EpochEntry> epochEntries = remoteLogManager.getLeaderEpochEntries(mockLog, 100, 200);
@@ -295,7 +295,7 @@ public class RemoteLogManagerTest {
         );
         checkpoint.write(totalEpochEntries);
         LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint, scheduler);
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
         TopicIdPartition tpId = new TopicIdPartition(Uuid.randomUuid(), tp);
         OffsetAndEpoch offsetAndEpoch = remoteLogManager.findHighestRemoteOffset(tpId, mockLog);
         assertEquals(new OffsetAndEpoch(-1L, -1), offsetAndEpoch);
@@ -309,7 +309,7 @@ public class RemoteLogManagerTest {
         );
         checkpoint.write(totalEpochEntries);
         LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint, scheduler);
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
         TopicIdPartition tpId = new TopicIdPartition(Uuid.randomUuid(), tp);
         when(remoteLogMetadataManager.highestOffsetForEpoch(eq(tpId), anyInt())).thenAnswer(ans -> {
             Integer epoch = ans.getArgument(1, Integer.class);
@@ -332,7 +332,7 @@ public class RemoteLogManagerTest {
         );
         checkpoint.write(totalEpochEntries);
         LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint, scheduler);
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
         TopicIdPartition tpId = new TopicIdPartition(Uuid.randomUuid(), tp);
         when(remoteLogMetadataManager.highestOffsetForEpoch(eq(tpId), anyInt())).thenAnswer(ans -> {
             Integer epoch = ans.getArgument(1, Integer.class);
@@ -501,7 +501,7 @@ public class RemoteLogManagerTest {
         // leader epoch preparation
         checkpoint.write(totalEpochEntries);
         LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint, scheduler);
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
         when(remoteLogMetadataManager.highestOffsetForEpoch(any(TopicIdPartition.class), anyInt())).thenReturn(Optional.of(-1L));
 
         File tempFile = TestUtils.tempFile();
@@ -615,7 +615,7 @@ public class RemoteLogManagerTest {
         // leader epoch preparation
         checkpoint.write(totalEpochEntries);
         LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint, scheduler);
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
         when(remoteLogMetadataManager.highestOffsetForEpoch(any(TopicIdPartition.class), anyInt())).thenReturn(Optional.of(-1L));
 
         File tempFile = TestUtils.tempFile();
@@ -707,7 +707,7 @@ public class RemoteLogManagerTest {
         // leader epoch preparation
         checkpoint.write(totalEpochEntries);
         LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint, scheduler);
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
         when(remoteLogMetadataManager.highestOffsetForEpoch(any(TopicIdPartition.class), anyInt())).thenReturn(Optional.of(-1L));
 
         File tempFile = TestUtils.tempFile();
@@ -797,7 +797,7 @@ public class RemoteLogManagerTest {
         // leader epoch preparation
         checkpoint.write(totalEpochEntries);
         LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint, scheduler);
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
         when(remoteLogMetadataManager.highestOffsetForEpoch(any(TopicIdPartition.class), anyInt())).thenReturn(Optional.of(0L));
 
         File tempFile = TestUtils.tempFile();
@@ -916,7 +916,7 @@ public class RemoteLogManagerTest {
         // leader epoch preparation
         checkpoint.write(totalEpochEntries);
         LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint, scheduler);
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
         when(remoteLogMetadataManager.highestOffsetForEpoch(any(TopicIdPartition.class), anyInt()))
                 .thenReturn(Optional.of(0L))
                 .thenReturn(Optional.of(nextSegmentStartOffset - 1));
@@ -995,7 +995,7 @@ public class RemoteLogManagerTest {
         // simulate altering log dir completes, and the new partition leader changes to the same broker in different log dir (dir2)
         mockLog = mock(UnifiedLog.class);
         when(mockLog.parentDir()).thenReturn("dir2");
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
         when(mockLog.config()).thenReturn(logConfig);
         when(mockLog.logEndOffset()).thenReturn(500L);
 
@@ -1031,7 +1031,7 @@ public class RemoteLogManagerTest {
         // leader epoch preparation
         checkpoint.write(totalEpochEntries);
         LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint, scheduler);
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
         when(remoteLogMetadataManager.highestOffsetForEpoch(any(TopicIdPartition.class), anyInt())).thenReturn(Optional.of(0L));
 
         File tempFile = TestUtils.tempFile();
@@ -1195,7 +1195,7 @@ public class RemoteLogManagerTest {
         // leader epoch preparation
         checkpoint.write(totalEpochEntries);
         LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint, scheduler);
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
         when(remoteLogMetadataManager.highestOffsetForEpoch(any(TopicIdPartition.class), anyInt())).thenReturn(Optional.of(0L));
 
         File tempFile = TestUtils.tempFile();
@@ -1270,7 +1270,7 @@ public class RemoteLogManagerTest {
         // leader epoch preparation
         checkpoint.write(totalEpochEntries);
         LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint, scheduler);
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
 
         // Throw a retryable exception to indicate that the remote log metadata manager is not initialized yet
         when(remoteLogMetadataManager.highestOffsetForEpoch(any(TopicIdPartition.class), anyInt()))
@@ -1440,7 +1440,7 @@ public class RemoteLogManagerTest {
     public void testFindNextSegmentWithTxnIndex() throws RemoteStorageException {
         checkpoint.write(totalEpochEntries);
         LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint, scheduler);
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
 
         when(remoteLogMetadataManager.highestOffsetForEpoch(any(TopicIdPartition.class), anyInt()))
                 .thenReturn(Optional.of(0L));
@@ -1471,7 +1471,7 @@ public class RemoteLogManagerTest {
     public void testFindNextSegmentWithTxnIndexTraversesNextEpoch() throws RemoteStorageException {
         checkpoint.write(totalEpochEntries);
         LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint, scheduler);
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
 
         when(remoteLogMetadataManager.highestOffsetForEpoch(any(TopicIdPartition.class), anyInt()))
                 .thenReturn(Optional.of(0L));
@@ -1696,7 +1696,7 @@ public class RemoteLogManagerTest {
         epochEntries.add(new EpochEntry(5, 200L));
         checkpoint.write(epochEntries);
         LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint, scheduler);
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
 
         long timestamp = time.milliseconds();
         RemoteLogSegmentMetadata metadata0 = new RemoteLogSegmentMetadata(new RemoteLogSegmentId(tpId, Uuid.randomUuid()),
@@ -2187,7 +2187,7 @@ public class RemoteLogManagerTest {
         checkpoint.write(epochEntries);
 
         LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint, scheduler);
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
 
         long timestamp = time.milliseconds();
         int segmentSize = 1024;
@@ -2225,7 +2225,7 @@ public class RemoteLogManagerTest {
         checkpoint.write(epochEntries);
 
         LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint, scheduler);
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
         when(mockLog.localLogStartOffset()).thenReturn(250L);
         when(remoteLogMetadataManager.listRemoteLogSegments(eq(leaderTopicIdPartition), anyInt()))
                 .thenReturn(Collections.emptyIterator());
@@ -2250,7 +2250,7 @@ public class RemoteLogManagerTest {
         checkpoint.write(epochEntries);
 
         LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint, scheduler);
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
 
         RemoteLogSegmentMetadata metadata = mock(RemoteLogSegmentMetadata.class);
         when(metadata.startOffset()).thenReturn(600L);
@@ -2350,7 +2350,7 @@ public class RemoteLogManagerTest {
         // leader epoch preparation
         checkpoint.write(Collections.singletonList(epochEntry0));
         LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint, scheduler);
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
 
         // create 2 log segments, with 0 and 150 as log start offset
         LogSegment oldSegment = mock(LogSegment.class);
@@ -2455,7 +2455,7 @@ public class RemoteLogManagerTest {
         List<EpochEntry> epochEntries = Collections.singletonList(epochEntry0);
         checkpoint.write(epochEntries);
         LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint, scheduler);
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
 
         when(mockLog.topicPartition()).thenReturn(leaderTopicIdPartition.topicPartition());
         when(mockLog.logEndOffset()).thenReturn(200L);
@@ -2507,7 +2507,7 @@ public class RemoteLogManagerTest {
         List<EpochEntry> epochEntries = Collections.singletonList(epochEntry0);
         checkpoint.write(epochEntries);
         LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint, scheduler);
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
 
         when(mockLog.topicPartition()).thenReturn(leaderTopicIdPartition.topicPartition());
         when(mockLog.logEndOffset()).thenReturn(200L);
@@ -2575,7 +2575,7 @@ public class RemoteLogManagerTest {
         List<EpochEntry> epochEntries = Collections.singletonList(epochEntry0);
         checkpoint.write(epochEntries);
         LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint, scheduler);
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
 
         when(mockLog.topicPartition()).thenReturn(leaderTopicIdPartition.topicPartition());
         when(mockLog.logEndOffset()).thenReturn(200L);
@@ -2622,7 +2622,7 @@ public class RemoteLogManagerTest {
         List<EpochEntry> epochEntries = Collections.singletonList(epochEntry0);
         checkpoint.write(epochEntries);
         LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint, scheduler);
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
         when(mockLog.topicPartition()).thenReturn(leaderTopicIdPartition.topicPartition());
         when(mockLog.logEndOffset()).thenReturn(2000L);
 
@@ -2716,7 +2716,7 @@ public class RemoteLogManagerTest {
 
         checkpoint.write(epochEntries);
         LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint, scheduler);
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
 
         Map<String, Long> logProps = new HashMap<>();
         logProps.put("retention.bytes", -1L);
@@ -2786,7 +2786,7 @@ public class RemoteLogManagerTest {
 
         checkpoint.write(epochEntries);
         LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint, scheduler);
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
 
         assertDoesNotThrow(leaderTask::cleanupExpiredRemoteLogSegments);
 
@@ -2806,7 +2806,7 @@ public class RemoteLogManagerTest {
         List<EpochEntry> epochEntries = Collections.singletonList(epochEntry0);
         checkpoint.write(epochEntries);
         LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint, scheduler);
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
 
         when(mockLog.topicPartition()).thenReturn(leaderTopicIdPartition.topicPartition());
         when(mockLog.logEndOffset()).thenReturn(200L);
@@ -2876,7 +2876,7 @@ public class RemoteLogManagerTest {
 
         long localLogStartOffset = (long) segmentCount * recordsPerSegment;
         long logEndOffset = ((long) segmentCount * recordsPerSegment) + 1;
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
         when(mockLog.localLogStartOffset()).thenReturn(localLogStartOffset);
         when(mockLog.logEndOffset()).thenReturn(logEndOffset);
         when(mockLog.onlyLocalLogSegmentsSize()).thenReturn(localLogSegmentsSize);
@@ -2914,7 +2914,7 @@ public class RemoteLogManagerTest {
 
         long localLogStartOffset = (long) segmentCount * recordsPerSegment;
         long logEndOffset = ((long) segmentCount * recordsPerSegment) + 1;
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
         when(mockLog.localLogStartOffset()).thenReturn(localLogStartOffset);
         when(mockLog.logEndOffset()).thenReturn(logEndOffset);
         when(mockLog.onlyLocalLogSegmentsSize()).thenReturn(localLogSegmentsSize);
@@ -3001,7 +3001,7 @@ public class RemoteLogManagerTest {
 
             checkpoint.write(epochEntries);
             LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint, scheduler);
-            when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+            when(mockLog.leaderEpochCache()).thenReturn(cache);
 
             Map<String, Long> logProps = new HashMap<>();
             logProps.put("retention.bytes", -1L);
@@ -3119,7 +3119,7 @@ public class RemoteLogManagerTest {
 
         when(remoteStorageManager.fetchLogSegment(any(RemoteLogSegmentMetadata.class), anyInt()))
                 .thenAnswer(a -> fileInputStream);
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
 
         int fetchOffset = 0;
         int fetchMaxBytes = 10;
@@ -3149,21 +3149,25 @@ public class RemoteLogManagerTest {
                 return remoteLogMetadataManager;
             }
 
+            @Override
             public Optional<RemoteLogSegmentMetadata> fetchRemoteLogSegmentMetadata(TopicPartition topicPartition,
                                                                                     int epochForOffset, long offset) {
                 return Optional.of(segmentMetadata);
             }
 
+            @Override
             public Optional<RemoteLogSegmentMetadata> findNextSegmentMetadata(RemoteLogSegmentMetadata segmentMetadata,
-                                                                              Option<LeaderEpochFileCache> leaderEpochFileCacheOption) {
+                                                                              LeaderEpochFileCache leaderEpochFileCacheOption) {
                 return Optional.empty();
             }
 
+            @Override
             int lookupPositionForOffset(RemoteLogSegmentMetadata remoteLogSegmentMetadata, long offset) {
                 return 1;
             }
 
             // This is the key scenario that we are testing here
+            @Override
             EnrichedRecordBatch findFirstBatch(RemoteLogInputStream remoteLogInputStream, long offset) {
                 return new EnrichedRecordBatch(null, 0);
             }
@@ -3189,7 +3193,7 @@ public class RemoteLogManagerTest {
 
         when(remoteStorageManager.fetchLogSegment(any(RemoteLogSegmentMetadata.class), anyInt()))
                 .thenAnswer(a -> fileInputStream);
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
 
         int fetchOffset = 0;
         int fetchMaxBytes = 10;
@@ -3264,7 +3268,7 @@ public class RemoteLogManagerTest {
         RemoteLogSegmentMetadata segmentMetadata = mock(RemoteLogSegmentMetadata.class);
         LeaderEpochFileCache cache = mock(LeaderEpochFileCache.class);
         when(cache.epochForOffset(anyLong())).thenReturn(OptionalInt.of(1));
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
 
         int fetchOffset = 0;
         int fetchMaxBytes = 10;
@@ -3469,7 +3473,7 @@ public class RemoteLogManagerTest {
         // leader epoch preparation
         checkpoint.write(totalEpochEntries);
         LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint, scheduler);
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
         when(mockLog.parentDir()).thenReturn("dir1");
         when(remoteLogMetadataManager.highestOffsetForEpoch(any(TopicIdPartition.class), anyInt())).thenReturn(Optional.of(0L));
 
@@ -3532,7 +3536,7 @@ public class RemoteLogManagerTest {
         // leader epoch preparation
         checkpoint.write(totalEpochEntries);
         LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint, scheduler);
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
         when(remoteLogMetadataManager.highestOffsetForEpoch(any(TopicIdPartition.class), anyInt())).thenReturn(Optional.of(0L));
 
         // create 3 log segments
@@ -3631,7 +3635,7 @@ public class RemoteLogManagerTest {
     public void testRemoteReadFetchDataInfo() throws RemoteStorageException, IOException {
         checkpoint.write(totalEpochEntries);
         LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint, scheduler);
-        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        when(mockLog.leaderEpochCache()).thenReturn(cache);
         when(remoteLogMetadataManager.remoteLogSegmentMetadata(eq(leaderTopicIdPartition), anyInt(), anyLong()))
                 .thenAnswer(ans -> {
                     long offset = ans.getArgument(2);
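
Every stubbing change in this test file is the same mechanical rewrite; condensed as a sketch (Scala syntax for consistency with the other sketches, with the `tp`, `checkpoint`, `scheduler` and `mockLog` fixtures assumed):

    // Sketch: the mocked log returns the cache directly, not Option.apply(cache).
    import org.mockito.Mockito.when
    val cache = new LeaderEpochFileCache(tp, checkpoint, scheduler)
    when(mockLog.leaderEpochCache()).thenReturn(cache)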
diff --git a/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala b/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala
index 8bdc80e9a4f..c96eef55420 100644
--- a/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala
+++ b/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala
@@ -50,7 +50,6 @@ import org.mockito.Mockito.{mock, when}
 
 import scala.concurrent.duration._
 import scala.jdk.CollectionConverters._
-import scala.jdk.OptionConverters.RichOption
 
 /**
  * Verifies that slow appends to log don't block request threads processing replica fetch requests.
@@ -302,8 +301,8 @@ class PartitionLockTest extends Logging {
         val log = super.createLog(isNew, isFutureReplica, offsetCheckpoints, None, None)
         val logDirFailureChannel = new LogDirFailureChannel(1)
         val segments = new LogSegments(log.topicPartition)
-        val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(
-          log.dir, log.topicPartition, logDirFailureChannel, "", None, mockTime.scheduler)
+        val leaderEpochCache = UnifiedLog.createLeaderEpochCache(
+          log.dir, log.topicPartition, logDirFailureChannel, None, mockTime.scheduler)
         val maxTransactionTimeout = 5 * 60 * 1000
         val producerStateManagerConfig = new ProducerStateManagerConfig(TransactionLogConfig.PRODUCER_ID_EXPIRATION_MS_DEFAULT, false)
         val producerStateManager = new ProducerStateManager(
         val producerStateManager = new ProducerStateManager(
@@ -324,7 +323,7 @@ class PartitionLockTest extends Logging {
           segments,
           0L,
           0L,
-          leaderEpochCache.toJava,
+          leaderEpochCache,
           producerStateManager,
           new ConcurrentHashMap[String, Integer],
           false
@@ -444,7 +443,7 @@ class PartitionLockTest extends Logging {
     log: UnifiedLog,
     logStartOffset: Long,
     localLog: LocalLog,
-    leaderEpochCache: Option[LeaderEpochFileCache],
+    leaderEpochCache: LeaderEpochFileCache,
     producerStateManager: ProducerStateManager,
     appendSemaphore: Semaphore
   ) extends UnifiedLog(
diff --git a/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala b/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
index b8ddaae026a..3dbcb952fa0 100644
--- a/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
+++ b/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
@@ -445,8 +445,8 @@ class PartitionTest extends AbstractPartitionTest {
         val log = super.createLog(isNew, isFutureReplica, offsetCheckpoints, None, None)
         val logDirFailureChannel = new LogDirFailureChannel(1)
         val segments = new LogSegments(log.topicPartition)
-        val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(
-          log.dir, log.topicPartition, logDirFailureChannel, "", None, time.scheduler)
+        val leaderEpochCache = UnifiedLog.createLeaderEpochCache(
+          log.dir, log.topicPartition, logDirFailureChannel, None, time.scheduler)
         val maxTransactionTimeoutMs = 5 * 60 * 1000
         val producerStateManagerConfig = new ProducerStateManagerConfig(TransactionLogConfig.PRODUCER_ID_EXPIRATION_MS_DEFAULT, true)
         val producerStateManager = new ProducerStateManager(
@@ -467,7 +467,7 @@ class PartitionTest extends AbstractPartitionTest {
           segments,
           0L,
           0L,
-          leaderEpochCache.asJava,
+          leaderEpochCache,
           producerStateManager,
           new ConcurrentHashMap[String, Integer],
           false
@@ -3124,7 +3124,7 @@ class PartitionTest extends AbstractPartitionTest {
     assertEquals(Some(0L), partition.leaderEpochStartOffsetOpt)
 
     val leaderLog = partition.localLogOrException
-    assertEquals(Optional.of(new EpochEntry(leaderEpoch, 0L)), leaderLog.leaderEpochCache.toJava.flatMap(_.latestEntry))
+    assertEquals(Optional.of(new EpochEntry(leaderEpoch, 0L)), leaderLog.leaderEpochCache.latestEntry)
 
     // Write to the log to increment the log end offset.
     leaderLog.appendAsLeader(MemoryRecords.withRecords(0L, Compression.NONE, 0,
@@ -3148,7 +3148,7 @@ class PartitionTest extends AbstractPartitionTest {
     assertEquals(leaderEpoch, partition.getLeaderEpoch)
     assertEquals(Set(leaderId), partition.partitionState.isr)
     assertEquals(Some(0L), partition.leaderEpochStartOffsetOpt)
-    assertEquals(Optional.of(new EpochEntry(leaderEpoch, 0L)), leaderLog.leaderEpochCache.toJava.flatMap(_.latestEntry))
+    assertEquals(Optional.of(new EpochEntry(leaderEpoch, 0L)), leaderLog.leaderEpochCache.latestEntry)
   }
 
   @Test
@@ -3628,7 +3628,7 @@ class PartitionTest extends AbstractPartitionTest {
     log: UnifiedLog,
     logStartOffset: Long,
     localLog: LocalLog,
-    leaderEpochCache: Option[LeaderEpochFileCache],
+    leaderEpochCache: LeaderEpochFileCache,
     producerStateManager: ProducerStateManager,
     appendSemaphore: Semaphore
   ) extends UnifiedLog(
diff --git a/core/src/test/scala/unit/kafka/log/AbstractLogCleanerIntegrationTest.scala b/core/src/test/scala/unit/kafka/log/AbstractLogCleanerIntegrationTest.scala
index 87470168527..e0a6724d081 100644
--- a/core/src/test/scala/unit/kafka/log/AbstractLogCleanerIntegrationTest.scala
+++ b/core/src/test/scala/unit/kafka/log/AbstractLogCleanerIntegrationTest.scala
@@ -24,7 +24,7 @@ import kafka.utils.Implicits._
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.compress.Compression
 import org.apache.kafka.common.config.TopicConfig
-import org.apache.kafka.common.record.{MemoryRecords, RecordBatch}
+import org.apache.kafka.common.record.{MemoryRecords, RecordBatch, RecordVersion}
 import org.apache.kafka.common.utils.Utils
 import org.apache.kafka.coordinator.transaction.TransactionLogConfig
 import org.apache.kafka.server.util.MockTime
@@ -147,8 +147,8 @@ abstract class AbstractLogCleanerIntegrationTest {
                 startKey: Int = 0, magicValue: Byte = RecordBatch.CURRENT_MAGIC_VALUE): Seq[(Int, String, Long)] = {
     for (_ <- 0 until numDups; key <- startKey until (startKey + numKeys)) yield {
       val value = counter.toString
-      val appendInfo = log.appendAsLeader(TestUtils.singletonRecords(value = value.getBytes, codec = codec,
-        key = key.toString.getBytes, magicValue = magicValue), leaderEpoch = 0)
+      val appendInfo = log.appendAsLeaderWithRecordVersion(TestUtils.singletonRecords(value = value.getBytes, codec = codec,
+        key = key.toString.getBytes, magicValue = magicValue), leaderEpoch = 0, recordVersion = RecordVersion.lookup(magicValue))
       // move LSO forward to increase compaction bound
       log.updateHighWatermark(log.logEndOffset)
       incCounter()
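
Because compaction now always writes v2, the test-only `appendAsLeaderWithRecordVersion` path above is the only way to get v0/v1 batches on disk. A hedged usage sketch (the record contents are illustrative):

    // Sketch: appending a legacy v0 record through the restored test-only API.
    val v0Records = TestUtils.singletonRecords(value = "v".getBytes, key = "k".getBytes,
      codec = Compression.NONE, magicValue = RecordBatch.MAGIC_VALUE_V0)
    log.appendAsLeaderWithRecordVersion(v0Records, leaderEpoch = 0,
      recordVersion = RecordVersion.lookup(RecordBatch.MAGIC_VALUE_V0)) // RecordVersion.V0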
diff --git a/core/src/test/scala/unit/kafka/log/LogCleanerManagerTest.scala b/core/src/test/scala/unit/kafka/log/LogCleanerManagerTest.scala
index 974da551e77..796536780b1 100644
--- a/core/src/test/scala/unit/kafka/log/LogCleanerManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogCleanerManagerTest.scala
@@ -37,7 +37,6 @@ import java.lang.{Long => JLong}
 import java.util
 import java.util.concurrent.ConcurrentHashMap
 import scala.collection.mutable
-import scala.jdk.OptionConverters.RichOption
 
 /**
   * Unit tests for the log cleaning logic
@@ -110,8 +109,8 @@ class LogCleanerManagerTest extends Logging {
     val maxTransactionTimeoutMs = 5 * 60 * 1000
     val producerIdExpirationCheckIntervalMs = TransactionLogConfig.PRODUCER_ID_EXPIRATION_CHECK_INTERVAL_MS_DEFAULT
     val segments = new LogSegments(tp)
-    val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(
-      tpDir, topicPartition, logDirFailureChannel, "", None, time.scheduler)
+    val leaderEpochCache = UnifiedLog.createLeaderEpochCache(
+      tpDir, topicPartition, logDirFailureChannel, None, time.scheduler)
     val producerStateManager = new ProducerStateManager(topicPartition, tpDir, maxTransactionTimeoutMs, producerStateManagerConfig, time)
     val offsets = new LogLoader(
       tpDir,
@@ -124,7 +123,7 @@ class LogCleanerManagerTest extends Logging {
       segments,
       0L,
       0L,
-      leaderEpochCache.toJava,
+      leaderEpochCache,
       producerStateManager,
       new ConcurrentHashMap[String, Integer],
       false
diff --git a/core/src/test/scala/unit/kafka/log/LogCleanerParameterizedIntegrationTest.scala b/core/src/test/scala/unit/kafka/log/LogCleanerParameterizedIntegrationTest.scala
index d0a7624ed79..df461855a9f 100755
--- a/core/src/test/scala/unit/kafka/log/LogCleanerParameterizedIntegrationTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogCleanerParameterizedIntegrationTest.scala
@@ -25,10 +25,11 @@ import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.compress.Compression
 import org.apache.kafka.common.config.TopicConfig
 import org.apache.kafka.common.record._
+import org.apache.kafka.common.utils.Time
 import org.apache.kafka.server.config.ServerConfigs
 import org.apache.kafka.server.util.MockTime
 import org.apache.kafka.storage.internals.checkpoint.OffsetCheckpointFile
-import org.apache.kafka.storage.internals.log.CleanerConfig
+import org.apache.kafka.storage.internals.log.{CleanerConfig, LogConfig}
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.extension.ExtensionContext
 import org.junit.jupiter.params.ParameterizedTest
@@ -134,6 +135,131 @@ class LogCleanerParameterizedIntegrationTest extends AbstractLogCleanerIntegrati
     assertEquals(toMap(messages), toMap(read), "Contents of the map shouldn't change")
   }
 
+  @ParameterizedTest
+  @ArgumentsSource(classOf[LogCleanerParameterizedIntegrationTest.ExcludeZstd])
+  def testCleanerWithMessageFormatV0V1V2(compressionType: CompressionType): Unit = {
+    val compression = Compression.of(compressionType).build()
+    val largeMessageKey = 20
+    val (largeMessageValue, largeMessageSet) = createLargeSingleMessageSet(largeMessageKey, RecordBatch.MAGIC_VALUE_V0, compression)
+    val maxMessageSize = compression match {
+      case Compression.NONE => largeMessageSet.sizeInBytes
+      case _ =>
+        // the broker assigns absolute offsets for message format 0 which potentially causes the compressed size to
+        // increase because the broker offsets are larger than the ones assigned by the client
+        // adding `6` to the message set size is good enough for this test: it covers the increased message size while
+        // still being less than the overhead introduced by the conversion from message format version 0 to 1
+        largeMessageSet.sizeInBytes + 6
+    }
+
+    cleaner = makeCleaner(partitions = topicPartitions, maxMessageSize = maxMessageSize)
+
+    val log = cleaner.logs.get(topicPartitions(0))
+    val props = logConfigProperties(maxMessageSize = maxMessageSize)
+    props.put(TopicConfig.MESSAGE_TIMESTAMP_TYPE_CONFIG, TimestampType.LOG_APPEND_TIME.name)
+    val logConfig = new LogConfig(props)
+    log.updateConfig(logConfig)
+
+    val appends1 = writeDups(numKeys = 100, numDups = 3, log = log, codec = compression, magicValue = RecordBatch.MAGIC_VALUE_V0)
+    val startSize = log.size
+    cleaner.startup()
+
+    val firstDirty = log.activeSegment.baseOffset
+    checkLastCleaned("log", 0, firstDirty)
+    val compactedSize = log.logSegments.asScala.map(_.size).sum
+    assertTrue(startSize > compactedSize, s"log should have been compacted: startSize=$startSize compactedSize=$compactedSize")
+
+    checkLogAfterAppendingDups(log, startSize, appends1)
+
+    val dupsV0 = writeDups(numKeys = 40, numDups = 3, log = log, codec = compression, magicValue = RecordBatch.MAGIC_VALUE_V0)
+    val appendInfo = log.appendAsLeaderWithRecordVersion(largeMessageSet, leaderEpoch = 0, recordVersion = RecordVersion.V0)
+    // move LSO forward to increase compaction bound
+    log.updateHighWatermark(log.logEndOffset)
+    val largeMessageOffset = appendInfo.firstOffset
+
+    // also add some messages with version 1 and version 2 to check that we handle mixed format versions correctly
+    val dupsV1 = writeDups(startKey = 30, numKeys = 40, numDups = 3, log = log, codec = compression, magicValue = RecordBatch.MAGIC_VALUE_V1)
+    val dupsV2 = writeDups(startKey = 15, numKeys = 5, numDups = 3, log = log, codec = compression, magicValue = RecordBatch.MAGIC_VALUE_V2)
+
+    val v0RecordKeysWithNoV1V2Updates = (appends1.map(_._1).toSet -- dupsV1.map(_._1) -- dupsV2.map(_._1)).map(_.toString)
+    val appends2: Seq[(Int, String, Long)] =
+      appends1 ++ dupsV0 ++ Seq((largeMessageKey, largeMessageValue, largeMessageOffset)) ++ dupsV1 ++ dupsV2
+
+    // roll the log so that all appended messages can be compacted
+    log.roll()
+    val firstDirty2 = log.activeSegment.baseOffset
+    checkLastCleaned("log", 0, firstDirty2)
+
+    checkLogAfterAppendingDups(log, startSize, appends2)
+    checkLogAfterConvertingToV2(compressionType, log, logConfig.messageTimestampType, v0RecordKeysWithNoV1V2Updates)
+  }
+
+  @ParameterizedTest
+  @ArgumentsSource(classOf[LogCleanerParameterizedIntegrationTest.ExcludeZstd])
+  def testCleaningNestedMessagesWithV0V1(compressionType: CompressionType): Unit = {
+    val compression = Compression.of(compressionType).build()
+    val maxMessageSize = 192
+    cleaner = makeCleaner(partitions = topicPartitions, maxMessageSize = maxMessageSize, segmentSize = 256)
+
+    val log = cleaner.logs.get(topicPartitions(0))
+    val logConfig = new LogConfig(logConfigProperties(maxMessageSize = maxMessageSize, segmentSize = 256))
+    log.updateConfig(logConfig)
+
+    // with compression enabled, these messages will be written as a single message containing all the individual messages
+    var appendsV0 = writeDupsSingleMessageSet(numKeys = 2, numDups = 3, log = log, codec = compression, magicValue = RecordBatch.MAGIC_VALUE_V0)
+    appendsV0 ++= writeDupsSingleMessageSet(numKeys = 2, startKey = 3, numDups = 2, log = log, codec = compression, magicValue = RecordBatch.MAGIC_VALUE_V0)
+
+    var appendsV1 = writeDupsSingleMessageSet(startKey = 4, numKeys = 2, numDups = 2, log = log, codec = compression, magicValue = RecordBatch.MAGIC_VALUE_V1)
+    appendsV1 ++= writeDupsSingleMessageSet(startKey = 4, numKeys = 2, numDups = 2, log = log, codec = compression, magicValue = RecordBatch.MAGIC_VALUE_V1)
+    appendsV1 ++= writeDupsSingleMessageSet(startKey = 6, numKeys = 2, numDups = 2, log = log, codec = compression, magicValue = RecordBatch.MAGIC_VALUE_V1)
+
+    val appends = appendsV0 ++ appendsV1
+
+    val v0RecordKeysWithNoV1V2Updates = (appendsV0.map(_._1).toSet -- appendsV1.map(_._1)).map(_.toString)
+
+    // roll the log so that all appended messages can be compacted
+    log.roll()
+    val startSize = log.size
+    cleaner.startup()
+
+    val firstDirty = log.activeSegment.baseOffset
+    assertTrue(firstDirty >= appends.size) // ensure we clean data from V0 and V1
+
+    checkLastCleaned("log", 0, firstDirty)
+    val compactedSize = log.logSegments.asScala.map(_.size).sum
+    assertTrue(startSize > compactedSize, s"log should have been compacted: startSize=$startSize compactedSize=$compactedSize")
+
+    checkLogAfterAppendingDups(log, startSize, appends)
+    checkLogAfterConvertingToV2(compressionType, log, logConfig.messageTimestampType, v0RecordKeysWithNoV1V2Updates)
+  }
+
+  private def checkLogAfterConvertingToV2(compressionType: CompressionType, log: UnifiedLog, timestampType: TimestampType,
+                                          keysForV0RecordsWithNoV1V2Updates: Set[String]): Unit = {
+    for (segment <- log.logSegments.asScala; recordBatch <- segment.log.batches.asScala) {
+      // Uncompressed v0/v1 records are always converted into single record v2 batches via compaction if they are retained
+      // Compressed v0/v1 record batches are converted into record batches v2 with one or more records (depending on the
+      // number of retained records after compaction)
+      assertEquals(RecordVersion.V2.value, recordBatch.magic)
+      if (compressionType == CompressionType.NONE)
+        assertEquals(1, recordBatch.iterator().asScala.size)
+      else
+        assertTrue(recordBatch.iterator().asScala.size >= 1)
+
+      val firstRecordKey = TestUtils.readString(recordBatch.iterator().next().key())
+      if (keysForV0RecordsWithNoV1V2Updates.contains(firstRecordKey))
+        assertEquals(TimestampType.CREATE_TIME, recordBatch.timestampType)
+      else
+        assertEquals(timestampType, recordBatch.timestampType)
+
+      recordBatch.iterator.asScala.foreach { record =>
+        val recordKey = TestUtils.readString(record.key)
+        if (keysForV0RecordsWithNoV1V2Updates.contains(recordKey))
+          assertEquals(RecordBatch.NO_TIMESTAMP, record.timestamp, "Record " + recordKey + " with unexpected timestamp")
+        else
+          assertNotEquals(RecordBatch.NO_TIMESTAMP, record.timestamp, "Record " + recordKey + " with unexpected timestamp " + RecordBatch.NO_TIMESTAMP)
+      }
+    }
+  }
+
   @ParameterizedTest
   @ArgumentsSource(classOf[LogCleanerParameterizedIntegrationTest.AllCompressions])
   def cleanerConfigUpdateTest(compressionType: CompressionType): Unit = {
@@ -213,6 +339,28 @@ class LogCleanerParameterizedIntegrationTest extends AbstractLogCleanerIntegrati
       (key, value, deepLogEntry.offset)
     }
   }
+
+  private def writeDupsSingleMessageSet(numKeys: Int, numDups: Int, log: UnifiedLog, codec: Compression,
+                                        startKey: Int = 0, magicValue: Byte): Seq[(Int, String, Long)] = {
+    val kvs = for (_ <- 0 until numDups; key <- startKey until (startKey + numKeys)) yield {
+      val payload = counter.toString
+      incCounter()
+      (key, payload)
+    }
+
+    val records = kvs.map { case (key, payload) =>
+      new SimpleRecord(Time.SYSTEM.milliseconds(), key.toString.getBytes, payload.getBytes)
+    }
+
+    val appendInfo = log.appendAsLeaderWithRecordVersion(MemoryRecords.withRecords(magicValue, codec, records: _*),
+      leaderEpoch = 0, recordVersion = RecordVersion.lookup(magicValue))
+    // move LSO forward to increase compaction bound
+    log.updateHighWatermark(log.logEndOffset)
+    val offsets = appendInfo.firstOffset to appendInfo.lastOffset
+
+    kvs.zip(offsets).map { case (kv, offset) => (kv._1, kv._2, offset) }
+  }
+
 }
 
 object LogCleanerParameterizedIntegrationTest {
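
The invariant these tests encode: after cleaning, every surviving batch carries magic v2, uncompressed legacy records end up one per batch, and v0 records never overwritten by v1/v2 keep the -1 timestamp sentinel under CreateTime. The core check, reduced to a standalone sketch (assumes a compacted `UnifiedLog` named `log` and JUnit assertions in scope):

    // Sketch of the post-compaction invariant verified by checkLogAfterConvertingToV2.
    import scala.jdk.CollectionConverters._
    for (segment <- log.logSegments.asScala; batch <- segment.log.batches.asScala) {
      assertEquals(RecordVersion.V2.value, batch.magic)  // all retained batches are v2
      if (compressionType == CompressionType.NONE)
        assertEquals(1, batch.iterator().asScala.size)   // uncompressed: single-record batches
    }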
diff --git a/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala b/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
index e4ebcb2d5da..9100cc7af21 100644
--- a/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
@@ -46,7 +46,6 @@ import java.util.Properties
 import java.util.concurrent.{ConcurrentHashMap, CountDownLatch, TimeUnit}
 import scala.collection._
 import scala.jdk.CollectionConverters._
-import scala.jdk.OptionConverters.RichOption
 
 /**
  * Unit tests for the log cleaning logic
@@ -189,8 +188,8 @@ class LogCleanerTest extends Logging {
     val maxTransactionTimeoutMs = 5 * 60 * 1000
     val producerIdExpirationCheckIntervalMs = TransactionLogConfig.PRODUCER_ID_EXPIRATION_CHECK_INTERVAL_MS_DEFAULT
     val logSegments = new LogSegments(topicPartition)
-    val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(
-      dir, topicPartition, logDirFailureChannel, "", None, time.scheduler)
+    val leaderEpochCache = UnifiedLog.createLeaderEpochCache(
+      dir, topicPartition, logDirFailureChannel, None, time.scheduler)
     val producerStateManager = new ProducerStateManager(topicPartition, dir,
       maxTransactionTimeoutMs, producerStateManagerConfig, time)
     val offsets = new LogLoader(
@@ -204,7 +203,7 @@ class LogCleanerTest extends Logging {
       logSegments,
       0L,
       0L,
-      leaderEpochCache.toJava,
+      leaderEpochCache,
       producerStateManager,
       new ConcurrentHashMap[String, Integer],
       false
diff --git a/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala b/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala
index d6324d95c3a..8043c53e30c 100644
--- a/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala
@@ -52,7 +52,7 @@ import java.util.{Optional, OptionalLong, Properties}
 import scala.collection.mutable.ListBuffer
 import scala.collection.{Iterable, Map, mutable}
 import scala.jdk.CollectionConverters._
-import scala.jdk.OptionConverters.{RichOption, RichOptional}
+import scala.jdk.OptionConverters.RichOptional
 
 class LogLoaderTest {
   var config: KafkaConfig = _
@@ -155,13 +155,13 @@ class LogLoaderTest {
           val logStartOffset = logStartOffsets.getOrDefault(topicPartition, 0L)
           val logDirFailureChannel: LogDirFailureChannel = new LogDirFailureChannel(1)
           val segments = new LogSegments(topicPartition)
-          val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(
-            logDir, topicPartition, logDirFailureChannel, "", None, time.scheduler)
+          val leaderEpochCache = UnifiedLog.createLeaderEpochCache(
+            logDir, topicPartition, logDirFailureChannel, None, time.scheduler)
           val producerStateManager = new ProducerStateManager(topicPartition, logDir,
             this.maxTransactionTimeoutMs, this.producerStateManagerConfig, time)
           val logLoader = new LogLoader(logDir, topicPartition, config, time.scheduler, time,
             logDirFailureChannel, hadCleanShutdown, segments, logStartOffset, logRecoveryPoint,
-            leaderEpochCache.toJava, producerStateManager, new ConcurrentHashMap[String, Integer], false)
+            leaderEpochCache, producerStateManager, new ConcurrentHashMap[String, Integer], false)
           val offsets = logLoader.load()
           val localLog = new LocalLog(logDir, logConfig, segments, offsets.recoveryPoint,
             offsets.nextOffsetMetadata, mockTime.scheduler, mockTime, topicPartition,
@@ -357,13 +357,13 @@ class LogLoaderTest {
           }.when(wrapper).read(ArgumentMatchers.any(), ArgumentMatchers.any(), ArgumentMatchers.any(), ArgumentMatchers.any())
           Mockito.doAnswer { in =>
             recoveredSegments += wrapper
-            segment.recover(in.getArgument(0, classOf[ProducerStateManager]), in.getArgument(1, classOf[Optional[LeaderEpochFileCache]]))
+            segment.recover(in.getArgument(0, classOf[ProducerStateManager]), in.getArgument(1, classOf[LeaderEpochFileCache]))
           }.when(wrapper).recover(ArgumentMatchers.any(), ArgumentMatchers.any())
           super.add(wrapper)
         }
       }
-      val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(
-        logDir, topicPartition, logDirFailureChannel, "", None, mockTime.scheduler)
+      val leaderEpochCache = UnifiedLog.createLeaderEpochCache(
+        logDir, topicPartition, logDirFailureChannel, None, mockTime.scheduler)
       val producerStateManager = new ProducerStateManager(topicPartition, logDir,
         maxTransactionTimeoutMs, producerStateManagerConfig, mockTime)
       val logLoader = new LogLoader(
@@ -377,7 +377,7 @@ class LogLoaderTest {
         interceptedLogSegments,
         0L,
         recoveryPoint,
-        leaderEpochCache.toJava,
+        leaderEpochCache,
         producerStateManager,
         new ConcurrentHashMap[String, Integer],
         false
@@ -430,8 +430,8 @@ class LogLoaderTest {
     val logDirFailureChannel: LogDirFailureChannel = new LogDirFailureChannel(1)
     val config = new LogConfig(new Properties())
     val segments = new LogSegments(topicPartition)
-    val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(
-      logDir, topicPartition, logDirFailureChannel, "", None, mockTime.scheduler)
+    val leaderEpochCache = UnifiedLog.createLeaderEpochCache(
+      logDir, topicPartition, logDirFailureChannel, None, mockTime.scheduler)
     val offsets = new LogLoader(
       logDir,
       topicPartition,
@@ -443,7 +443,7 @@ class LogLoaderTest {
       segments,
       0L,
       0L,
-      leaderEpochCache.toJava,
+      leaderEpochCache,
       stateManager,
       new ConcurrentHashMap[String, Integer],
       false
@@ -540,8 +540,8 @@ class LogLoaderTest {
     val config = new LogConfig(new Properties())
     val logDirFailureChannel = null
     val segments = new LogSegments(topicPartition)
-    val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(
-      logDir, topicPartition, logDirFailureChannel, "", None, mockTime.scheduler)
+    val leaderEpochCache = UnifiedLog.createLeaderEpochCache(
+      logDir, topicPartition, logDirFailureChannel, None, mockTime.scheduler)
     val offsets = new LogLoader(
       logDir,
       topicPartition,
@@ -553,7 +553,7 @@ class LogLoaderTest {
       segments,
       0L,
       0L,
-      leaderEpochCache.toJava,
+      leaderEpochCache,
       stateManager,
       new ConcurrentHashMap[String, Integer],
       false
@@ -1215,7 +1215,7 @@ class LogLoaderTest {
   @Test
   def testLogRecoversForLeaderEpoch(): Unit = {
     val log = createLog(logDir, new LogConfig(new Properties))
-    val leaderEpochCache = log.leaderEpochCache.get
+    val leaderEpochCache = log.leaderEpochCache
     val firstBatch = singletonRecordsWithLeaderEpoch(value = "random".getBytes, leaderEpoch = 1, offset = 0)
     log.appendAsFollower(records = firstBatch)
 
@@ -1237,7 +1237,7 @@ class LogLoaderTest {
 
     // reopen the log and recover from the beginning
     val recoveredLog = createLog(logDir, new LogConfig(new Properties), lastShutdownClean = false)
-    val recoveredLeaderEpochCache = recoveredLog.leaderEpochCache.get
+    val recoveredLeaderEpochCache = recoveredLog.leaderEpochCache
 
     // epoch entries should be recovered
     assertEquals(java.util.Arrays.asList(new EpochEntry(1, 0), new EpochEntry(2, 1), new EpochEntry(3, 3)), recoveredLeaderEpochCache.epochEntries)
@@ -1633,8 +1633,8 @@ class LogLoaderTest {
     log.logSegments.forEach(segment => segments.add(segment))
     assertEquals(5, segments.firstSegment.get.baseOffset)
 
-    val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(
-      logDir, topicPartition, logDirFailureChannel, "", None, mockTime.scheduler)
+    val leaderEpochCache = UnifiedLog.createLeaderEpochCache(
+      logDir, topicPartition, logDirFailureChannel, None, mockTime.scheduler)
     val offsets = new LogLoader(
       logDir,
       topicPartition,
@@ -1646,7 +1646,7 @@ class LogLoaderTest {
       segments,
       0L,
       0L,
-      leaderEpochCache.toJava,
+      leaderEpochCache,
       stateManager,
       new ConcurrentHashMap[String, Integer],
       isRemoteLogEnabled
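
A note on the API change driving these hunks: UnifiedLog.createLeaderEpochCache
replaces maybeCreateLeaderEpochCache, takes one argument fewer, and returns a
LeaderEpochFileCache directly instead of an Option, so call sites drop the
.toJava/.get conversions. A minimal Scala sketch; the log directory is a
hypothetical placeholder and mockTime.scheduler mirrors the test fixtures above:

    import java.io.File
    import kafka.log.UnifiedLog
    import org.apache.kafka.common.TopicPartition
    import org.apache.kafka.server.util.MockTime
    import org.apache.kafka.storage.internals.log.LogDirFailureChannel

    val logDir = new File("/tmp/kafka-logs/topic-0") // hypothetical log directory
    logDir.mkdirs()
    val mockTime = new MockTime()
    val cache = UnifiedLog.createLeaderEpochCache(
      logDir, new TopicPartition("topic", 0), new LogDirFailureChannel(1), None, mockTime.scheduler)
    cache.assign(5, 0L) // record that epoch 5 starts at offset 0
    assert(cache.latestEpoch == java.util.OptionalInt.of(5))
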
diff --git a/core/src/test/scala/unit/kafka/log/LogTestUtils.scala b/core/src/test/scala/unit/kafka/log/LogTestUtils.scala
index 6e27ea75944..e98028ab86f 100644
--- a/core/src/test/scala/unit/kafka/log/LogTestUtils.scala
+++ b/core/src/test/scala/unit/kafka/log/LogTestUtils.scala
@@ -35,7 +35,6 @@ import org.apache.kafka.coordinator.transaction.TransactionLogConfig
 import org.apache.kafka.server.config.ServerLogConfigs
 import org.apache.kafka.server.storage.log.FetchIsolation
 import org.apache.kafka.server.util.Scheduler
-import org.apache.kafka.storage.internals.checkpoint.LeaderEpochCheckpointFile
 import org.apache.kafka.storage.internals.log.LogConfig.{DEFAULT_REMOTE_LOG_COPY_DISABLE_CONFIG, DEFAULT_REMOTE_LOG_DELETE_ON_DISABLE_CONFIG}
 import org.apache.kafka.storage.internals.log.{AbortedTxn, AppendOrigin, FetchDataInfo, LazyIndex, LogAppendInfo, LogConfig, LogDirFailureChannel, LogFileUtils, LogOffsetsListener, LogSegment, ProducerStateManager, ProducerStateManagerConfig, TransactionIndex}
 import org.apache.kafka.storage.log.metrics.BrokerTopicStats
@@ -262,12 +261,6 @@ object LogTestUtils {
   def listProducerSnapshotOffsets(logDir: File): Seq[Long] =
     ProducerStateManager.listSnapshotFiles(logDir).asScala.map(_.offset).sorted.toSeq
 
-  def assertLeaderEpochCacheEmpty(log: UnifiedLog): Unit = {
-    assertEquals(None, log.leaderEpochCache)
-    assertEquals(None, log.latestEpoch)
-    assertFalse(LeaderEpochCheckpointFile.newFile(log.dir).exists())
-  }
-
   def appendNonTransactionalAsLeader(log: UnifiedLog, numRecords: Int): Unit = {
     val simpleRecords = (0 until numRecords).map { seq =>
       new SimpleRecord(s"$seq".getBytes)
diff --git a/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala b/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala
index 8906c21175e..feb2a9770ec 100755
--- a/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala
+++ b/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala
@@ -57,11 +57,10 @@ import java.io._
 import java.nio.ByteBuffer
 import java.nio.file.Files
 import java.util.concurrent.{Callable, ConcurrentHashMap, Executors, TimeUnit}
-import java.util.{Optional, OptionalLong, Properties}
+import java.util.{Optional, OptionalInt, OptionalLong, Properties}
 import scala.collection.immutable.SortedSet
 import scala.collection.mutable.ListBuffer
 import scala.jdk.CollectionConverters._
-import scala.jdk.OptionConverters.{RichOptional, RichOptionalInt}
 
 class UnifiedLogTest {
   var config: KafkaConfig = _
@@ -655,23 +654,20 @@ class UnifiedLogTest {
     val records = TestUtils.records(List(new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes)),
       baseOffset = 27)
     appendAsFollower(log, records, leaderEpoch = 19)
-    assertEquals(Some(new EpochEntry(19, 27)),
-      log.leaderEpochCache.flatMap(_.latestEntry.toScala))
+    assertEquals(Optional.of(new EpochEntry(19, 27)), log.leaderEpochCache.latestEntry)
     assertEquals(29, log.logEndOffset)
 
     def verifyTruncationClearsEpochCache(epoch: Int, truncationOffset: Long): Unit = {
       // Simulate becoming a leader
-      log.maybeAssignEpochStartOffset(leaderEpoch = epoch, startOffset = log.logEndOffset)
-      assertEquals(Some(new EpochEntry(epoch, 29)),
-        log.leaderEpochCache.flatMap(_.latestEntry.toScala))
+      log.assignEpochStartOffset(leaderEpoch = epoch, startOffset = log.logEndOffset)
+      assertEquals(Optional.of(new EpochEntry(epoch, 29)), log.leaderEpochCache.latestEntry)
       assertEquals(29, log.logEndOffset)
 
       // Now we become the follower and truncate to an offset greater
       // than or equal to the log end offset. The trivial epoch entry
       // at the end of the log should be gone
       log.truncateTo(truncationOffset)
-      assertEquals(Some(new EpochEntry(19, 27)),
-        log.leaderEpochCache.flatMap(_.latestEntry.toScala))
+      assertEquals(Optional.of(new EpochEntry(19, 27)), log.leaderEpochCache.latestEntry)
       assertEquals(29, log.logEndOffset)
     }
 
@@ -817,11 +813,11 @@ class UnifiedLogTest {
     records.batches.forEach(_.setPartitionLeaderEpoch(0))
 
     val filtered = ByteBuffer.allocate(2048)
-    records.filterTo(new TopicPartition("foo", 0), new RecordFilter(0, 0) {
+    records.filterTo(new RecordFilter(0, 0) {
       override def checkBatchRetention(batch: RecordBatch): RecordFilter.BatchRetentionResult =
         new RecordFilter.BatchRetentionResult(RecordFilter.BatchRetention.DELETE_EMPTY, false)
       override def shouldRetainRecord(recordBatch: RecordBatch, record: Record): Boolean = !record.hasKey
-    }, filtered, Int.MaxValue, BufferSupplier.NO_CACHING)
+    }, filtered, BufferSupplier.NO_CACHING)
     filtered.flip()
     val filteredRecords = MemoryRecords.readableRecords(filtered)
 
@@ -859,11 +855,11 @@ class UnifiedLogTest {
     records.batches.forEach(_.setPartitionLeaderEpoch(0))
 
     val filtered = ByteBuffer.allocate(2048)
-    records.filterTo(new TopicPartition("foo", 0), new RecordFilter(0, 0) {
+    records.filterTo(new RecordFilter(0, 0) {
       override def checkBatchRetention(batch: RecordBatch): RecordFilter.BatchRetentionResult =
         new RecordFilter.BatchRetentionResult(RecordFilter.BatchRetention.RETAIN_EMPTY, true)
       override def shouldRetainRecord(recordBatch: RecordBatch, record: Record): Boolean = false
-    }, filtered, Int.MaxValue, BufferSupplier.NO_CACHING)
+    }, filtered, BufferSupplier.NO_CACHING)
     filtered.flip()
     val filteredRecords = MemoryRecords.readableRecords(filtered)
 
@@ -903,11 +899,11 @@ class UnifiedLogTest {
     records.batches.forEach(_.setPartitionLeaderEpoch(0))
 
     val filtered = ByteBuffer.allocate(2048)
-    records.filterTo(new TopicPartition("foo", 0), new RecordFilter(0, 0) {
+    records.filterTo(new RecordFilter(0, 0) {
       override def checkBatchRetention(batch: RecordBatch): RecordFilter.BatchRetentionResult =
         new RecordFilter.BatchRetentionResult(RecordFilter.BatchRetention.DELETE_EMPTY, false)
       override def shouldRetainRecord(recordBatch: RecordBatch, record: Record): Boolean = !record.hasKey
-    }, filtered, Int.MaxValue, BufferSupplier.NO_CACHING)
+    }, filtered, BufferSupplier.NO_CACHING)
     filtered.flip()
     val filteredRecords = MemoryRecords.readableRecords(filtered)
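
The three hunks above track the same signature change: MemoryRecords.filterTo no
longer takes a TopicPartition or a maximum batch size, presumably because the
v0/v1 rewrite paths that needed them are gone. A self-contained Scala sketch of
the new call shape; the record contents and retention policy are illustrative only:

    import java.nio.ByteBuffer
    import org.apache.kafka.common.compress.Compression
    import org.apache.kafka.common.record._
    import org.apache.kafka.common.utils.BufferSupplier

    val records = MemoryRecords.withRecords(0L, Compression.NONE, 0,
      new SimpleRecord("k".getBytes, "v".getBytes), new SimpleRecord("unkeyed".getBytes))
    val buffer = ByteBuffer.allocate(2048)
    records.filterTo(new RecordFilter(0, 0) {
      override def checkBatchRetention(batch: RecordBatch): RecordFilter.BatchRetentionResult =
        new RecordFilter.BatchRetentionResult(RecordFilter.BatchRetention.DELETE_EMPTY, false)
      override def shouldRetainRecord(recordBatch: RecordBatch, record: Record): Boolean =
        record.hasKey // keep keyed records only, as a compaction-style filter would
    }, buffer, BufferSupplier.NO_CACHING)
    buffer.flip()
    val keyedOnly = MemoryRecords.readableRecords(buffer) // now holds just the keyed record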
 
@@ -2060,7 +2056,7 @@ class UnifiedLogTest {
 
     // The cache can be updated directly after a leader change.
     // The new latest offset should reflect the updated epoch.
-    log.maybeAssignEpochStartOffset(2, 2L)
+    log.assignEpochStartOffset(2, 2L)
 
    assertEquals(new OffsetResultHolder(new TimestampAndOffset(ListOffsetsResponse.UNKNOWN_TIMESTAMP, 2L, Optional.of(2))),
       log.fetchOffsetByTimestamp(ListOffsetsRequest.LATEST_TIMESTAMP))
@@ -2136,7 +2132,7 @@ class UnifiedLogTest {
         .filter(_ == firstTimestamp)
         .map[TimestampAndOffset](x => new TimestampAndOffset(x, 0L, Optional.of(firstLeaderEpoch)))
     }).when(remoteLogManager).findOffsetByTimestamp(ArgumentMatchers.eq(log.topicPartition),
-      anyLong(), anyLong(), ArgumentMatchers.eq(log.leaderEpochCache.get))
+      anyLong(), anyLong(), ArgumentMatchers.eq(log.leaderEpochCache))
     log._localLogStartOffset = 1
 
     def assertFetchOffsetByTimestamp(expected: Option[TimestampAndOffset], timestamp: Long): Unit = {
@@ -2161,7 +2157,7 @@ class UnifiedLogTest {
 
     // The cache can be updated directly after a leader change.
     // The new latest offset should reflect the updated epoch.
-    log.maybeAssignEpochStartOffset(2, 2L)
+    log.assignEpochStartOffset(2, 2L)
     
    assertEquals(new OffsetResultHolder(new TimestampAndOffset(ListOffsetsResponse.UNKNOWN_TIMESTAMP, 2L, Optional.of(2))),
      log.fetchOffsetByTimestamp(ListOffsetsRequest.LATEST_TIMESTAMP, Some(remoteLogManager)))
@@ -2235,7 +2231,7 @@ class UnifiedLogTest {
         .filter(_ == firstTimestamp)
         .map[TimestampAndOffset](x => new TimestampAndOffset(x, 0L, Optional.of(firstLeaderEpoch)))
     }).when(remoteLogManager).findOffsetByTimestamp(ArgumentMatchers.eq(log.topicPartition),
-      anyLong(), anyLong(), ArgumentMatchers.eq(log.leaderEpochCache.get))
+      anyLong(), anyLong(), ArgumentMatchers.eq(log.leaderEpochCache))
     log._localLogStartOffset = 1
     log._highestOffsetInRemoteStorage = 0
 
@@ -2263,7 +2259,7 @@ class UnifiedLogTest {
 
     // The cache can be updated directly after a leader change.
     // The new latest offset should reflect the updated epoch.
-    log.maybeAssignEpochStartOffset(2, 2L)
+    log.assignEpochStartOffset(2, 2L)
 
    assertEquals(new OffsetResultHolder(new TimestampAndOffset(ListOffsetsResponse.UNKNOWN_TIMESTAMP, 2L, Optional.of(2))),
      log.fetchOffsetByTimestamp(ListOffsetsRequest.LATEST_TIMESTAMP, Some(remoteLogManager)))
@@ -2578,12 +2574,29 @@ class UnifiedLogTest {
     val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1000, indexIntervalBytes = 1, maxMessageBytes = 64 * 1024)
     val log = createLog(logDir, logConfig)
     log.appendAsLeader(TestUtils.records(List(new SimpleRecord("foo".getBytes()))), leaderEpoch = 5)
-    assertEquals(Some(5), log.leaderEpochCache.flatMap(_.latestEpoch.toScala))
+    assertEquals(OptionalInt.of(5), log.leaderEpochCache.latestEpoch)
 
     log.appendAsFollower(TestUtils.records(List(new SimpleRecord("foo".getBytes())),
       baseOffset = 1L,
       magicValue = RecordVersion.V1.value))
-    assertEquals(None, log.leaderEpochCache.flatMap(_.latestEpoch.toScala))
+    assertEquals(OptionalInt.empty, log.leaderEpochCache.latestEpoch)
+  }
+
+  @Test
+  def testLeaderEpochCacheCreatedAfterMessageFormatUpgrade(): Unit = {
+    val logProps = new Properties()
+    logProps.put(TopicConfig.SEGMENT_BYTES_CONFIG, "1000")
+    logProps.put(TopicConfig.INDEX_INTERVAL_BYTES_CONFIG, "1")
+    logProps.put(TopicConfig.MAX_MESSAGE_BYTES_CONFIG, "65536")
+    val logConfig = new LogConfig(logProps)
+    val log = createLog(logDir, logConfig)
+    log.appendAsLeaderWithRecordVersion(TestUtils.records(List(new SimpleRecord("bar".getBytes())),
+      magicValue = RecordVersion.V1.value), leaderEpoch = 5, RecordVersion.V1)
+    assertEquals(None, log.latestEpoch)
+
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("foo".getBytes())),
+      magicValue = RecordVersion.V2.value), leaderEpoch = 5)
+    assertEquals(Some(5), log.latestEpoch)
   }
 
   @Test
@@ -2671,8 +2684,8 @@ class UnifiedLogTest {
     for (_ <- 0 until 100)
       log.appendAsLeader(createRecords, leaderEpoch = 0)
 
-    log.maybeAssignEpochStartOffset(0, 40)
-    log.maybeAssignEpochStartOffset(1, 90)
+    log.assignEpochStartOffset(0, 40)
+    log.assignEpochStartOffset(1, 90)
 
     // segments are not eligible for deletion if no high watermark has been set
     val numSegments = log.numberOfSegments
@@ -2757,9 +2770,7 @@ class UnifiedLogTest {
     assertEquals(log.logStartOffset, 15)
   }
 
-  def epochCache(log: UnifiedLog): LeaderEpochFileCache = {
-    log.leaderEpochCache.get
-  }
+  def epochCache(log: UnifiedLog): LeaderEpochFileCache = log.leaderEpochCache
 
   @Test
   def shouldDeleteSizeBasedSegments(): Unit = {
@@ -2888,7 +2899,7 @@ class UnifiedLogTest {
     //Given this partition is on leader epoch 72
     val epoch = 72
     val log = createLog(logDir, new LogConfig(new Properties))
-    log.maybeAssignEpochStartOffset(epoch, records.length)
+    log.assignEpochStartOffset(epoch, records.length)
 
     //When appending messages as a leader (i.e. assignOffsets = true)
     for (record <- records)
@@ -3662,14 +3673,9 @@ class UnifiedLogTest {
     assertTrue(newDir.exists())
 
     log.renameDir(newDir.getName, false)
-    assertTrue(log.leaderEpochCache.isEmpty)
+    assertFalse(log.leaderEpochCache.nonEmpty)
     assertTrue(log.partitionMetadataFile.isEmpty)
     assertEquals(0, log.logEndOffset)
-    // verify that records appending can still succeed
-    // even with the uninitialized leaderEpochCache and partitionMetadataFile
-    val records = TestUtils.records(List(new SimpleRecord(mockTime.milliseconds, "key".getBytes, "value".getBytes)))
-    log.appendAsLeader(records, leaderEpoch = 0)
-    assertEquals(1, log.logEndOffset)
 
     // verify that the background deletion can succeed
     log.delete()
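
Two renames run through this file: maybeAssignEpochStartOffset is now
assignEpochStartOffset, and log.leaderEpochCache returns the cache itself rather
than an Option, so assertions compare against Java Optional/OptionalInt values.
A short Scala sketch of the new pattern, assuming a log built as in these tests:

    import java.util.OptionalInt
    import kafka.log.UnifiedLog

    def demoEpochAssignment(log: UnifiedLog): Unit = {
      log.assignEpochStartOffset(2, 2L)  // was maybeAssignEpochStartOffset
      val cache = log.leaderEpochCache   // LeaderEpochFileCache; no Option, no .get
      assert(cache.latestEpoch == OptionalInt.of(2))
    }
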
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
index a58342f61e1..a3081f17ed3 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
@@ -67,6 +67,7 @@ import org.apache.kafka.server.storage.log.{FetchIsolation, FetchParams, FetchPa
 import org.apache.kafka.server.util.timer.MockTimer
 import org.apache.kafka.server.util.{MockScheduler, MockTime}
 import org.apache.kafka.storage.internals.checkpoint.{LazyOffsetCheckpoints, OffsetCheckpointFile, PartitionMetadataFile}
+import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache
 import org.apache.kafka.storage.internals.log.{AppendOrigin, FetchDataInfo, LocalLog, LogConfig, LogDirFailureChannel, LogLoader, LogOffsetMetadata, LogOffsetSnapshot, LogSegments, LogStartOffsetIncrementReason, ProducerStateManager, ProducerStateManagerConfig, RemoteStorageFetchInfo, VerificationGuard}
 import org.apache.kafka.storage.log.metrics.BrokerTopicStats
 import org.junit.jupiter.api.Assertions._
@@ -265,7 +266,7 @@ class ReplicaManagerTest {
   }
 
   @Test
-  def testMaybeAddLogDirFetchersWithoutEpochCache(): Unit = {
+  def testMaybeAddLogDirFetchers(): Unit = {
     val dir1 = TestUtils.tempDir()
     val dir2 = TestUtils.tempDir()
     val props = TestUtils.createBrokerConfig(0)
@@ -310,8 +311,6 @@ class ReplicaManagerTest {
 
       partition.createLogIfNotExists(isNew = true, isFutureReplica = true,
         new LazyOffsetCheckpoints(rm.highWatermarkCheckpoints.asJava), None)
-      // remove cache to disable OffsetsForLeaderEpoch API
-      partition.futureLog.get.leaderEpochCache = None
 
       // this method should use hw of future log to create log dir fetcher. Otherwise, it causes offset mismatch error
       rm.maybeAddLogDirFetchers(Set(partition), new LazyOffsetCheckpoints(rm.highWatermarkCheckpoints.asJava), _ => None)
@@ -2901,8 +2900,8 @@ class ReplicaManagerTest {
     val maxTransactionTimeoutMs = 30000
     val maxProducerIdExpirationMs = 30000
     val segments = new LogSegments(tp)
-    val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(
-      logDir, tp, mockLogDirFailureChannel, "", None, time.scheduler)
+    val leaderEpochCache = UnifiedLog.createLeaderEpochCache(
+      logDir, tp, mockLogDirFailureChannel, None, time.scheduler)
     val producerStateManager = new ProducerStateManager(tp, logDir,
       maxTransactionTimeoutMs, new ProducerStateManagerConfig(maxProducerIdExpirationMs, true), time)
     val offsets = new LogLoader(
@@ -2916,7 +2915,7 @@ class ReplicaManagerTest {
       segments,
       0L,
       0L,
-      leaderEpochCache.toJava,
+      leaderEpochCache,
       producerStateManager,
       new ConcurrentHashMap[String, Integer],
       false
@@ -4517,7 +4516,7 @@ class ReplicaManagerTest {
     when(mockLog.logStartOffset).thenReturn(endOffset).thenReturn(startOffset)
     when(mockLog.logEndOffset).thenReturn(endOffset)
     when(mockLog.localLogStartOffset()).thenReturn(endOffset - 10)
-    when(mockLog.leaderEpochCache).thenReturn(None)
+    when(mockLog.leaderEpochCache).thenReturn(mock(classOf[LeaderEpochFileCache]))
     when(mockLog.latestEpoch).thenReturn(Some(0))
     val producerStateManager = mock(classOf[ProducerStateManager])
     when(mockLog.producerStateManager).thenReturn(producerStateManager)
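
Since the cache is no longer optional, mocks of UnifiedLog must stub it with a
cache instance rather than None. A minimal Mockito sketch mirroring the hunk above:

    import kafka.log.UnifiedLog
    import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache
    import org.mockito.Mockito.{mock, when}

    val mockLog = mock(classOf[UnifiedLog])
    when(mockLog.leaderEpochCache).thenReturn(mock(classOf[LeaderEpochFileCache]))
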
diff --git a/core/src/test/scala/unit/kafka/utils/SchedulerTest.scala b/core/src/test/scala/unit/kafka/utils/SchedulerTest.scala
index cb889fe91a0..7afa2178f73 100644
--- a/core/src/test/scala/unit/kafka/utils/SchedulerTest.scala
+++ b/core/src/test/scala/unit/kafka/utils/SchedulerTest.scala
@@ -28,8 +28,6 @@ import org.apache.kafka.storage.log.metrics.BrokerTopicStats
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, Timeout}
 
-import scala.jdk.OptionConverters.RichOption
-
 
 class SchedulerTest {
 
@@ -140,8 +138,8 @@ class SchedulerTest {
     val topicPartition = UnifiedLog.parseTopicPartitionName(logDir)
     val logDirFailureChannel = new LogDirFailureChannel(10)
     val segments = new LogSegments(topicPartition)
-    val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(
-      logDir, topicPartition, logDirFailureChannel, "", None, mockTime.scheduler)
+    val leaderEpochCache = UnifiedLog.createLeaderEpochCache(
+      logDir, topicPartition, logDirFailureChannel, None, mockTime.scheduler)
     val producerStateManager = new ProducerStateManager(topicPartition, logDir,
       maxTransactionTimeoutMs, new ProducerStateManagerConfig(maxProducerIdExpirationMs, false), mockTime)
     val offsets = new LogLoader(
@@ -155,7 +153,7 @@ class SchedulerTest {
       segments,
       0L,
       0L,
-      leaderEpochCache.toJava,
+      leaderEpochCache,
       producerStateManager,
       new ConcurrentHashMap[String, Integer],
       false
diff --git a/storage/src/main/java/org/apache/kafka/storage/internals/log/LogLoader.java b/storage/src/main/java/org/apache/kafka/storage/internals/log/LogLoader.java
index 1ba58d1b2a9..89780686995 100644
--- a/storage/src/main/java/org/apache/kafka/storage/internals/log/LogLoader.java
+++ b/storage/src/main/java/org/apache/kafka/storage/internals/log/LogLoader.java
@@ -56,7 +56,7 @@ public class LogLoader {
     private final LogSegments segments;
     private final long logStartOffsetCheckpoint;
     private final long recoveryPointCheckpoint;
-    private final Optional<LeaderEpochFileCache> leaderEpochCache;
+    private final LeaderEpochFileCache leaderEpochCache;
     private final ProducerStateManager producerStateManager;
     private final ConcurrentMap<String, Integer> numRemainingSegments;
     private final boolean isRemoteLogEnabled;
@@ -74,7 +74,7 @@ public class LogLoader {
      * @param segments The {@link LogSegments} instance into which segments recovered from disk will be populated
      * @param logStartOffsetCheckpoint The checkpoint of the log start offset
      * @param recoveryPointCheckpoint The checkpoint of the offset at which to begin the recovery
-     * @param leaderEpochCache An optional {@link LeaderEpochFileCache} instance to be updated during recovery
+     * @param leaderEpochCache A {@link LeaderEpochFileCache} instance to be updated during recovery
      * @param producerStateManager The {@link ProducerStateManager} instance to be updated during recovery
      * @param numRemainingSegments The remaining segments to be recovered in this log keyed by recovery thread name
      * @param isRemoteLogEnabled Boolean flag to indicate whether the remote storage is enabled or not
@@ -90,7 +90,7 @@ public class LogLoader {
             LogSegments segments,
             long logStartOffsetCheckpoint,
             long recoveryPointCheckpoint,
-            Optional<LeaderEpochFileCache> leaderEpochCache,
+            LeaderEpochFileCache leaderEpochCache,
             ProducerStateManager producerStateManager,
             ConcurrentMap<String, Integer> numRemainingSegments,
             boolean isRemoteLogEnabled) {
@@ -215,13 +215,13 @@ public class LogLoader {
             recoveryOffsets = new RecoveryOffsets(0L, 0L);
         }
 
-        leaderEpochCache.ifPresent(lec -> lec.truncateFromEndAsyncFlush(recoveryOffsets.nextOffset));
+        leaderEpochCache.truncateFromEndAsyncFlush(recoveryOffsets.nextOffset);
         long newLogStartOffset = isRemoteLogEnabled
             ? logStartOffsetCheckpoint
             : Math.max(logStartOffsetCheckpoint, segments.firstSegment().get().baseOffset());
 
         // The earliest leader epoch may not be flushed during a hard failure. Recover it here.
-        leaderEpochCache.ifPresent(lec -> lec.truncateFromStartAsyncFlush(logStartOffsetCheckpoint));
+        leaderEpochCache.truncateFromStartAsyncFlush(logStartOffsetCheckpoint);
 
         // Any segment loading or recovery code must not use producerStateManager, so that we can build the full state here
         // from scratch.
@@ -428,7 +428,7 @@ public class LogLoader {
                         "is smaller than logStartOffset {}. " +
                         "This could happen if segment files were deleted from 
the file system.", logEndOffset, logStartOffsetCheckpoint);
                 removeAndDeleteSegmentsAsync(segments.values());
-                
leaderEpochCache.ifPresent(LeaderEpochFileCache::clearAndFlush);
+                leaderEpochCache.clearAndFlush();
                 
producerStateManager.truncateFullyAndStartAt(logStartOffsetCheckpoint);
                 return Optional.empty();
             }
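
With the Optional wrapper gone, LogLoader maintains the epoch cache
unconditionally during load. A Scala sketch of that maintenance step, using the
two truncation calls from the hunks above; the helper name and offsets are
hypothetical placeholders:

    import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache

    def syncEpochCacheAfterLoad(cache: LeaderEpochFileCache,
                                recoveredNextOffset: Long,
                                logStartOffset: Long): Unit = {
      cache.truncateFromEndAsyncFlush(recoveredNextOffset) // drop entries past the recovered end
      cache.truncateFromStartAsyncFlush(logStartOffset)    // drop entries below the log start
    }
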
diff --git a/storage/src/main/java/org/apache/kafka/storage/internals/log/LogSegment.java b/storage/src/main/java/org/apache/kafka/storage/internals/log/LogSegment.java
index 3312d42af02..15cd6c834a0 100644
--- a/storage/src/main/java/org/apache/kafka/storage/internals/log/LogSegment.java
+++ b/storage/src/main/java/org/apache/kafka/storage/internals/log/LogSegment.java
@@ -465,11 +465,11 @@ public class LogSegment implements Closeable {
      *
      * @param producerStateManager Producer state corresponding to the segment's base offset. This is needed to recover
      *                             the transaction index.
-     * @param leaderEpochCache Optionally a cache for updating the leader epoch during recovery.
+     * @param leaderEpochCache A cache for updating the leader epoch during recovery.
      * @return The number of bytes truncated from the log
      * @throws LogSegmentOffsetOverflowException if the log segment contains an offset that causes the index offset to overflow
      */
-    public int recover(ProducerStateManager producerStateManager, Optional<LeaderEpochFileCache> leaderEpochCache) throws IOException {
+    public int recover(ProducerStateManager producerStateManager, LeaderEpochFileCache leaderEpochCache) throws IOException {
         offsetIndex().reset();
         timeIndex().reset();
         txnIndex.reset();
@@ -495,11 +495,9 @@ public class LogSegment implements Closeable {
                 validBytes += batch.sizeInBytes();
 
                 if (batch.magic() >= RecordBatch.MAGIC_VALUE_V2) {
-                    leaderEpochCache.ifPresent(cache -> {
-                        if (batch.partitionLeaderEpoch() >= 0 &&
-                                (cache.latestEpoch().isEmpty() || batch.partitionLeaderEpoch() > cache.latestEpoch().getAsInt()))
-                            cache.assign(batch.partitionLeaderEpoch(), batch.baseOffset());
-                    });
+                    if (batch.partitionLeaderEpoch() >= 0 &&
+                            (leaderEpochCache.latestEpoch().isEmpty() || batch.partitionLeaderEpoch() > leaderEpochCache.latestEpoch().getAsInt()))
+                        leaderEpochCache.assign(batch.partitionLeaderEpoch(), batch.baseOffset());
                     updateProducerState(producerStateManager, batch);
                 }
             }
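
The recovery rule above is: only v2 batches carry a partition leader epoch, and
the cache advances only when a batch's epoch is strictly newer than the latest
one recorded. The same logic factored out as a Scala sketch; the helper name is
illustrative:

    import org.apache.kafka.common.record.RecordBatch
    import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache

    def maybeAssignEpoch(cache: LeaderEpochFileCache, batch: RecordBatch): Unit = {
      val epoch = batch.partitionLeaderEpoch
      if (batch.magic >= RecordBatch.MAGIC_VALUE_V2 && epoch >= 0 &&
          (cache.latestEpoch.isEmpty || epoch > cache.latestEpoch.getAsInt))
        cache.assign(epoch, batch.baseOffset) // the epoch begins at this batch's base offset
    }
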
diff --git a/storage/src/test/java/org/apache/kafka/storage/internals/log/LogSegmentTest.java b/storage/src/test/java/org/apache/kafka/storage/internals/log/LogSegmentTest.java
index 5e06c073dc5..dc6a0cfb3aa 100644
--- a/storage/src/test/java/org/apache/kafka/storage/internals/log/LogSegmentTest.java
+++ b/storage/src/test/java/org/apache/kafka/storage/internals/log/LogSegmentTest.java
@@ -440,7 +440,7 @@ public class LogSegmentTest {
             }
             File indexFile = seg.offsetIndexFile();
             writeNonsenseToFile(indexFile, 5, (int) indexFile.length());
-            seg.recover(newProducerStateManager(), Optional.empty());
+            seg.recover(newProducerStateManager(), mock(LeaderEpochFileCache.class));
             for (int i = 0; i < 100; i++) {
                 Iterable<Record> records = seg.read(i, 1, Optional.of((long) seg.size()), true).records.records();
                 assertEquals(i, records.iterator().next().offset());
@@ -482,7 +482,7 @@ public class LogSegmentTest {
                 107L, endTxnRecords(ControlRecordType.COMMIT, pid1, producerEpoch, 107L));
 
             ProducerStateManager stateManager = newProducerStateManager();
-            segment.recover(stateManager, Optional.empty());
+            segment.recover(stateManager, mock(LeaderEpochFileCache.class));
             assertEquals(108L, stateManager.mapEndOffset());
 
             List<AbortedTxn> abortedTxns = segment.txnIndex().allAbortedTxns();
@@ -498,7 +498,7 @@ public class LogSegmentTest {
             stateManager.loadProducerEntry(new ProducerStateEntry(pid2, producerEpoch, 0,
                 RecordBatch.NO_TIMESTAMP, OptionalLong.of(75L),
                 Optional.of(new BatchMetadata(10, 10L, 5, RecordBatch.NO_TIMESTAMP))));
-            segment.recover(stateManager, Optional.empty());
+            segment.recover(stateManager, mock(LeaderEpochFileCache.class));
             assertEquals(108L, stateManager.mapEndOffset());
 
             abortedTxns = segment.txnIndex().allAbortedTxns();
@@ -533,7 +533,7 @@ public class LogSegmentTest {
             seg.append(111L, RecordBatch.NO_TIMESTAMP, 110L, MemoryRecords.withRecords(110L, Compression.NONE, 2,
                 new SimpleRecord("a".getBytes()), new SimpleRecord("b".getBytes())));
 
-            seg.recover(newProducerStateManager(), Optional.of(cache));
+            seg.recover(newProducerStateManager(), cache);
             assertEquals(Arrays.asList(
                 new EpochEntry(0, 104L),
                 new EpochEntry(1, 106L),
@@ -570,7 +570,7 @@ public class LogSegmentTest {
             }
             File timeIndexFile = seg.timeIndexFile();
             writeNonsenseToFile(timeIndexFile, 5, (int) timeIndexFile.length());
-            seg.recover(newProducerStateManager(), Optional.empty());
+            seg.recover(newProducerStateManager(), mock(LeaderEpochFileCache.class));
             for (int i = 0; i < 100; i++) {
                 assertEquals(i, seg.findOffsetByTimestamp(i * 10, 0L).get().offset);
                 if (i < 99) {
@@ -597,7 +597,7 @@ public class LogSegmentTest {
                 FileRecords.LogOffsetPosition recordPosition = seg.log().searchForOffsetWithSize(offsetToBeginCorruption, 0);
                 int position = recordPosition.position + TestUtils.RANDOM.nextInt(15);
                 writeNonsenseToFile(seg.log().file(), position, (int) (seg.log().file().length() - position));
-                seg.recover(newProducerStateManager(), Optional.empty());
+                seg.recover(newProducerStateManager(), mock(LeaderEpochFileCache.class));
 
                 List<Long> expectList = new ArrayList<>();
                 for (long j = 0; j < offsetToBeginCorruption; j++) {
diff --git a/storage/src/test/java/org/apache/kafka/tiered/storage/TieredStorageTestContext.java b/storage/src/test/java/org/apache/kafka/tiered/storage/TieredStorageTestContext.java
index ff61e29d93c..2fdb9483fe6 100644
--- a/storage/src/test/java/org/apache/kafka/tiered/storage/TieredStorageTestContext.java
+++ b/storage/src/test/java/org/apache/kafka/tiered/storage/TieredStorageTestContext.java
@@ -302,11 +302,7 @@ public final class TieredStorageTestContext implements AutoCloseable {
 
     // unused now, but it can be reused later as this is a utility method.
     public Optional<LeaderEpochFileCache> leaderEpochFileCache(int brokerId, TopicPartition partition) {
-        Optional<UnifiedLog> unifiedLogOpt = log(brokerId, partition);
-        if (unifiedLogOpt.isPresent() && unifiedLogOpt.get().leaderEpochCache().isDefined()) {
-            return Optional.of(unifiedLogOpt.get().leaderEpochCache().get());
-        }
-        return Optional.empty();
+        return log(brokerId, partition).map(log -> log.leaderEpochCache());
     }
 
     public List<LocalTieredStorage> remoteStorageManagers() {
diff --git a/storage/src/test/java/org/apache/kafka/tiered/storage/actions/ExpectLeaderEpochCheckpointAction.java b/storage/src/test/java/org/apache/kafka/tiered/storage/actions/ExpectLeaderEpochCheckpointAction.java
index 10231ad06ff..da683979983 100644
--- a/storage/src/test/java/org/apache/kafka/tiered/storage/actions/ExpectLeaderEpochCheckpointAction.java
+++ b/storage/src/test/java/org/apache/kafka/tiered/storage/actions/ExpectLeaderEpochCheckpointAction.java
@@ -30,8 +30,6 @@ import java.util.Optional;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.atomic.AtomicReference;
 
-import scala.Option;
-
 public final class ExpectLeaderEpochCheckpointAction implements TieredStorageTestAction {
 
     private final Integer brokerId;
@@ -56,10 +54,8 @@ public final class ExpectLeaderEpochCheckpointAction implements TieredStorageTes
             EpochEntry earliestEntry = null;
             Optional<UnifiedLog> log = context.log(brokerId, partition);
             if (log.isPresent()) {
-                Option<LeaderEpochFileCache> leaderEpochCache = log.get().leaderEpochCache();
-                if (leaderEpochCache.isDefined()) {
-                    earliestEntry = leaderEpochCache.get().earliestEntry().orElse(null);
-                }
+                LeaderEpochFileCache leaderEpochCache = log.get().leaderEpochCache();
+                earliestEntry = leaderEpochCache.earliestEntry().orElse(null);
             }
             earliestEntryOpt.set(earliestEntry);
             return earliestEntry != null && beginEpoch == earliestEntry.epoch

