This is an automated email from the ASF dual-hosted git repository.

jsancio pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 4a8a0637e07 KAFKA-18723; Better handle invalid records during replication (#18852)
4a8a0637e07 is described below

commit 4a8a0637e07734779b40ba9785842311144f922c
Author: José Armando García Sancio <[email protected]>
AuthorDate: Tue Feb 25 20:09:19 2025 -0500

    KAFKA-18723; Better handle invalid records during replication (#18852)
    
    For the KRaft implementation there is a race between the network thread,
    which reads bytes from the log segments, and the KRaft driver thread,
    which truncates the log and appends records to the log. This race can
    cause the network thread to send corrupted records or inconsistent
    records. The corrupted records case is handled by catching and logging
    the CorruptRecordException. The inconsistent records case is handled by
    only appending record batches whose partition leader epoch is less than
    or equal to the fetching replica's epoch, provided the epoch did not
    change between the request and the response.
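
    To make the mechanism concrete, here is a minimal illustrative sketch of
    the follower-side handling. It assumes only the
    appendAsFollower(records, epoch) signature introduced by this patch; the
    appendFetched wrapper and its reporting are hypothetical, not code from
    this commit:

        import org.apache.kafka.common.errors.CorruptRecordException
        import org.apache.kafka.common.record.MemoryRecords
        import kafka.log.UnifiedLog

        // Append data fetched from the leader, bounding what is persisted to
        // the leader epoch that was in effect when the FETCH request was built.
        // Corrupted bytes are reported and skipped instead of failing the replica.
        def appendFetched(log: UnifiedLog, records: MemoryRecords, fetchEpoch: Int): Unit = {
          try {
            log.appendAsFollower(records, fetchEpoch)
          } catch {
            case e: CorruptRecordException =>
              // The race with truncation can hand the network thread garbage
              // bytes; report it (a stand-in for the fetcher's real logging)
              // and let a later fetch retry from a consistent offset.
              println(s"Ignoring corrupted records at epoch $fetchEpoch: ${e.getMessage}")
          }
        }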
    
    For the ISR implementation there is also a race between the network
    thread and the replica fetcher thread, which truncates the log and
    appends records to the log. This race can cause the network thread to
    send corrupted records or inconsistent records. The replica fetcher
    thread already handles the corrupted records case. The inconsistent
    records case is handled by only appending record batches whose partition
    leader epoch is less than or equal to the leader epoch in the FETCH
    request.
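
    The epoch bound itself amounts to keeping only the leading batches whose
    partition leader epoch does not exceed the epoch of the FETCH request. A
    standalone sketch of that filter follows; the helper name
    batchesToPersist is made up for illustration, while the condition mirrors
    the hasHigherPartitionLeaderEpoch check added to UnifiedLog in this patch:

        import org.apache.kafka.common.record.{MemoryRecords, RecordBatch}
        import scala.jdk.CollectionConverters._

        // Keep the leading batches that either carry no epoch or whose epoch
        // is within the bound; once a batch with a higher epoch is seen, it
        // and everything after it is dropped.
        def batchesToPersist(records: MemoryRecords, fetchEpoch: Int): Seq[RecordBatch] = {
          records.batches.asScala.toSeq.takeWhile { batch =>
            batch.partitionLeaderEpoch == RecordBatch.NO_PARTITION_LEADER_EPOCH ||
              batch.partitionLeaderEpoch <= fetchEpoch
          }
        }

    In the actual change this filtering happens inside the append path, so
    callers such as the replica fetcher only need to pass the epoch down.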
    
    Reviewers: Jun Rao <[email protected]>, Alyssa Huang <[email protected]>, Chia-Ping Tsai <[email protected]>
---
 build.gradle                                       |   8 +
 .../kafka/common/record/DefaultRecordBatch.java    |   3 +-
 .../apache/kafka/common/record/MemoryRecords.java  |   6 +-
 .../common/record/ArbitraryMemoryRecords.java      |  39 +++
 .../record/InvalidMemoryRecordsProvider.java       | 132 ++++++++
 core/src/main/scala/kafka/cluster/Partition.scala  |  20 +-
 core/src/main/scala/kafka/log/UnifiedLog.scala     | 137 +++++---
 .../main/scala/kafka/raft/KafkaMetadataLog.scala   |  18 +-
 .../scala/kafka/server/AbstractFetcherThread.scala |  21 +-
 .../kafka/server/ReplicaAlterLogDirsThread.scala   |  11 +-
 .../scala/kafka/server/ReplicaFetcherThread.scala  |  11 +-
 .../scala/kafka/raft/KafkaMetadataLogTest.scala    |  95 +++++-
 .../scala/unit/kafka/cluster/PartitionTest.scala   | 102 ++++--
 .../test/scala/unit/kafka/log/LogCleanerTest.scala |  22 +-
 .../scala/unit/kafka/log/LogConcurrencyTest.scala  |  11 +-
 .../test/scala/unit/kafka/log/LogLoaderTest.scala  |  32 +-
 .../test/scala/unit/kafka/log/UnifiedLogTest.scala | 371 ++++++++++++++++-----
 .../kafka/server/AbstractFetcherManagerTest.scala  |   9 +-
 .../kafka/server/AbstractFetcherThreadTest.scala   |  85 ++++-
 .../unit/kafka/server/MockFetcherThread.scala      |  35 +-
 .../kafka/server/ReplicaFetcherThreadTest.scala    |  33 +-
 .../unit/kafka/server/ReplicaManagerTest.scala     |   9 +-
 .../jmh/fetcher/ReplicaFetcherThreadBenchmark.java |   8 +-
 .../partition/PartitionMakeFollowerBenchmark.java  |   2 +-
 .../org/apache/kafka/raft/KafkaRaftClient.java     |  48 ++-
 .../java/org/apache/kafka/raft/ReplicatedLog.java  |   9 +-
 .../kafka/raft/KafkaRaftClientFetchTest.java       | 152 +++++++++
 .../test/java/org/apache/kafka/raft/MockLog.java   |  46 ++-
 .../java/org/apache/kafka/raft/MockLogTest.java    | 125 ++++++-
 29 files changed, 1298 insertions(+), 302 deletions(-)

diff --git a/build.gradle b/build.gradle
index a4bed55f1d1..9d04500a286 100644
--- a/build.gradle
+++ b/build.gradle
@@ -1037,6 +1037,7 @@ project(':core') {
     testImplementation project(':test-common:test-common-util')
     testImplementation libs.bcpkix
     testImplementation libs.mockitoCore
+    testImplementation libs.jqwik
     testImplementation(libs.apacheda) {
       exclude group: 'xml-apis', module: 'xml-apis'
       // `mina-core` is a transitive dependency for `apacheds` and `apacheda`.
@@ -1231,6 +1232,12 @@ project(':core') {
     )
   }
 
+  test {
+    useJUnitPlatform {
+      includeEngines 'jqwik', 'junit-jupiter'
+    }
+  }
+
   tasks.create(name: "copyDependantTestLibs", type: Copy) {
     from (configurations.testRuntimeClasspath) {
       include('*.jar')
@@ -1802,6 +1809,7 @@ project(':clients') {
     testImplementation libs.jacksonJakartarsJsonProvider
     testImplementation libs.jose4j
     testImplementation libs.junitJupiter
+    testImplementation libs.jqwik
     testImplementation libs.spotbugs
     testImplementation libs.mockitoCore
     testImplementation libs.mockitoJunitJupiter // supports MockitoExtension
diff --git a/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java b/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java
index 912c3490f43..d6e9cc6bd7f 100644
--- a/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java
+++ b/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java
@@ -159,7 +159,7 @@ public class DefaultRecordBatch extends AbstractRecordBatch 
implements MutableRe
 
     /**
      * Gets the base timestamp of the batch which is used to calculate the 
record timestamps from the deltas.
-     * 
+     *
      * @return The base timestamp
      */
     public long baseTimestamp() {
@@ -502,6 +502,7 @@ public class DefaultRecordBatch extends AbstractRecordBatch 
implements MutableRe
     public String toString() {
         return "RecordBatch(magic=" + magic() + ", offsets=[" + baseOffset() + 
", " + lastOffset() + "], " +
                 "sequence=[" + baseSequence() + ", " + lastSequence() + "], " +
+                "partitionLeaderEpoch=" + partitionLeaderEpoch() + ", " +
                 "isTransactional=" + isTransactional() + ", isControlBatch=" + 
isControlBatch() + ", " +
                 "compression=" + compressionType() + ", timestampType=" + 
timestampType() + ", crc=" + checksum() + ")";
     }
diff --git a/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java b/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
index 1aad97d5920..c06188edf22 100644
--- a/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
+++ b/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
@@ -32,9 +32,6 @@ import org.apache.kafka.common.utils.ByteBufferOutputStream;
 import org.apache.kafka.common.utils.CloseableIterator;
 import org.apache.kafka.common.utils.Utils;
 
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.nio.channels.GatheringByteChannel;
@@ -49,7 +46,6 @@ import java.util.Objects;
  * or one of the {@link #builder(ByteBuffer, byte, Compression, TimestampType, 
long)} variants.
  */
 public class MemoryRecords extends AbstractRecords {
-    private static final Logger log = 
LoggerFactory.getLogger(MemoryRecords.class);
     public static final MemoryRecords EMPTY = 
MemoryRecords.readableRecords(ByteBuffer.allocate(0));
 
     private final ByteBuffer buffer;
@@ -596,7 +592,7 @@ public class MemoryRecords extends AbstractRecords {
         return withRecords(magic, initialOffset, compression, 
TimestampType.CREATE_TIME, records);
     }
 
-    public static MemoryRecords withRecords(long initialOffset, Compression 
compression, Integer partitionLeaderEpoch, SimpleRecord... records) {
+    public static MemoryRecords withRecords(long initialOffset, Compression 
compression, int partitionLeaderEpoch, SimpleRecord... records) {
         return withRecords(RecordBatch.CURRENT_MAGIC_VALUE, initialOffset, 
compression, TimestampType.CREATE_TIME, RecordBatch.NO_PRODUCER_ID,
                 RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, 
partitionLeaderEpoch, false, records);
     }
diff --git a/clients/src/test/java/org/apache/kafka/common/record/ArbitraryMemoryRecords.java b/clients/src/test/java/org/apache/kafka/common/record/ArbitraryMemoryRecords.java
new file mode 100644
index 00000000000..30eec866a6c
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/record/ArbitraryMemoryRecords.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.record;
+
+import net.jqwik.api.Arbitraries;
+import net.jqwik.api.Arbitrary;
+import net.jqwik.api.ArbitrarySupplier;
+
+import java.nio.ByteBuffer;
+import java.util.Random;
+
+public final class ArbitraryMemoryRecords implements 
ArbitrarySupplier<MemoryRecords> {
+    @Override
+    public Arbitrary<MemoryRecords> get() {
+        return 
Arbitraries.randomValue(ArbitraryMemoryRecords::buildRandomRecords);
+    }
+
+    private static MemoryRecords buildRandomRecords(Random random) {
+        int size = random.nextInt(128) + 1;
+        byte[] bytes = new byte[size];
+        random.nextBytes(bytes);
+
+        return MemoryRecords.readableRecords(ByteBuffer.wrap(bytes));
+    }
+}
diff --git a/clients/src/test/java/org/apache/kafka/common/record/InvalidMemoryRecordsProvider.java b/clients/src/test/java/org/apache/kafka/common/record/InvalidMemoryRecordsProvider.java
new file mode 100644
index 00000000000..0f9446a6391
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/record/InvalidMemoryRecordsProvider.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.record;
+
+import org.apache.kafka.common.errors.CorruptRecordException;
+
+import org.junit.jupiter.api.extension.ExtensionContext;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.ArgumentsProvider;
+
+import java.nio.ByteBuffer;
+import java.util.Optional;
+import java.util.stream.Stream;
+
+public final class InvalidMemoryRecordsProvider implements ArgumentsProvider {
+    // Use a baseOffset that's not zero so that it is less likely to match the 
LEO
+    private static final long BASE_OFFSET = 1234;
+    private static final int EPOCH = 4321;
+
+    /**
+     * Returns a stream of arguments for invalid memory records and the 
expected exception.
+     *
+     * The first object in the {@code Arguments} is a {@code MemoryRecords}.
+     *
+     * The second object in the {@code Arguments} is an {@code 
Optional<Class<Exception>>} which is
+     * the expected exception from the log layer.
+     */
+    @Override
+    public Stream<? extends Arguments> provideArguments(ExtensionContext 
context) {
+        return Stream.of(
+            Arguments.of(MemoryRecords.readableRecords(notEnoughBytes()), 
Optional.empty()),
+            Arguments.of(MemoryRecords.readableRecords(recordsSizeTooSmall()), 
Optional.of(CorruptRecordException.class)),
+            
Arguments.of(MemoryRecords.readableRecords(notEnoughBytesToMagic()), 
Optional.empty()),
+            Arguments.of(MemoryRecords.readableRecords(negativeMagic()), 
Optional.of(CorruptRecordException.class)),
+            Arguments.of(MemoryRecords.readableRecords(largeMagic()), 
Optional.of(CorruptRecordException.class)),
+            
Arguments.of(MemoryRecords.readableRecords(lessBytesThanRecordSize()), 
Optional.empty())
+        );
+    }
+
+    private static ByteBuffer notEnoughBytes() {
+        var buffer = ByteBuffer.allocate(Records.LOG_OVERHEAD - 1);
+        buffer.limit(buffer.capacity());
+
+        return buffer;
+    }
+
+    private static ByteBuffer recordsSizeTooSmall() {
+        var buffer = ByteBuffer.allocate(256);
+        // Write the base offset
+        buffer.putLong(BASE_OFFSET);
+        // Write record size
+        buffer.putInt(LegacyRecord.RECORD_OVERHEAD_V0 - 1);
+        buffer.position(0);
+        buffer.limit(buffer.capacity());
+
+        return buffer;
+    }
+
+    private static ByteBuffer notEnoughBytesToMagic() {
+        var buffer = ByteBuffer.allocate(256);
+        // Write the base offset
+        buffer.putLong(BASE_OFFSET);
+        // Write record size
+        buffer.putInt(buffer.capacity() - Records.LOG_OVERHEAD);
+        buffer.position(0);
+        buffer.limit(Records.HEADER_SIZE_UP_TO_MAGIC - 1);
+
+        return buffer;
+    }
+
+    private static ByteBuffer negativeMagic() {
+        var buffer = ByteBuffer.allocate(256);
+        // Write the base offset
+        buffer.putLong(BASE_OFFSET);
+        // Write record size
+        buffer.putInt(buffer.capacity() - Records.LOG_OVERHEAD);
+        // Write the epoch
+        buffer.putInt(EPOCH);
+        // Write magic
+        buffer.put((byte) -1);
+        buffer.position(0);
+        buffer.limit(buffer.capacity());
+
+        return buffer;
+    }
+
+    private static ByteBuffer largeMagic() {
+        var buffer = ByteBuffer.allocate(256);
+        // Write the base offset
+        buffer.putLong(BASE_OFFSET);
+        // Write record size
+        buffer.putInt(buffer.capacity() - Records.LOG_OVERHEAD);
+        // Write the epoch
+        buffer.putInt(EPOCH);
+        // Write magic
+        buffer.put((byte) (RecordBatch.CURRENT_MAGIC_VALUE + 1));
+        buffer.position(0);
+        buffer.limit(buffer.capacity());
+
+        return buffer;
+    }
+
+    private static ByteBuffer lessBytesThanRecordSize() {
+        var buffer = ByteBuffer.allocate(256);
+        // Write the base offset
+        buffer.putLong(BASE_OFFSET);
+        // Write record size
+        buffer.putInt(buffer.capacity() - Records.LOG_OVERHEAD);
+        // Write the epoch
+        buffer.putInt(EPOCH);
+        // Write magic
+        buffer.put(RecordBatch.CURRENT_MAGIC_VALUE);
+        buffer.position(0);
+        buffer.limit(buffer.capacity() - Records.LOG_OVERHEAD - 1);
+
+        return buffer;
+    }
+}
diff --git a/core/src/main/scala/kafka/cluster/Partition.scala b/core/src/main/scala/kafka/cluster/Partition.scala
index 2a10afb3b5d..b2394cbb944 100755
--- a/core/src/main/scala/kafka/cluster/Partition.scala
+++ b/core/src/main/scala/kafka/cluster/Partition.scala
@@ -1302,27 +1302,35 @@ class Partition(val topicPartition: TopicPartition,
     }
   }
 
-  private def doAppendRecordsToFollowerOrFutureReplica(records: MemoryRecords, 
isFuture: Boolean): Option[LogAppendInfo] = {
+  private def doAppendRecordsToFollowerOrFutureReplica(
+    records: MemoryRecords,
+    isFuture: Boolean,
+    partitionLeaderEpoch: Int
+  ): Option[LogAppendInfo] = {
     if (isFuture) {
       // The read lock is needed to handle race condition if request handler 
thread tries to
       // remove future replica after receiving AlterReplicaLogDirsRequest.
       inReadLock(leaderIsrUpdateLock) {
         // Note the replica may be undefined if it is removed by a 
non-ReplicaAlterLogDirsThread before
         // this method is called
-        futureLog.map { _.appendAsFollower(records) }
+        futureLog.map { _.appendAsFollower(records, partitionLeaderEpoch) }
       }
     } else {
       // The lock is needed to prevent the follower replica from being updated 
while ReplicaAlterDirThread
       // is executing maybeReplaceCurrentWithFutureReplica() to replace 
follower replica with the future replica.
       futureLogLock.synchronized {
-        Some(localLogOrException.appendAsFollower(records))
+        Some(localLogOrException.appendAsFollower(records, 
partitionLeaderEpoch))
       }
     }
   }
 
-  def appendRecordsToFollowerOrFutureReplica(records: MemoryRecords, isFuture: 
Boolean): Option[LogAppendInfo] = {
+  def appendRecordsToFollowerOrFutureReplica(
+    records: MemoryRecords,
+    isFuture: Boolean,
+    partitionLeaderEpoch: Int
+  ): Option[LogAppendInfo] = {
     try {
-      doAppendRecordsToFollowerOrFutureReplica(records, isFuture)
+      doAppendRecordsToFollowerOrFutureReplica(records, isFuture, 
partitionLeaderEpoch)
     } catch {
       case e: UnexpectedAppendOffsetException =>
         val log = if (isFuture) futureLocalLogOrException else 
localLogOrException
@@ -1340,7 +1348,7 @@ class Partition(val topicPartition: TopicPartition,
           info(s"Unexpected offset in append to $topicPartition. First offset 
${e.firstOffset} is less than log start offset ${log.logStartOffset}." +
                s" Since this is the first record to be appended to the 
$replicaName's log, will start the log from offset ${e.firstOffset}.")
           truncateFullyAndStartAt(e.firstOffset, isFuture)
-          doAppendRecordsToFollowerOrFutureReplica(records, isFuture)
+          doAppendRecordsToFollowerOrFutureReplica(records, isFuture, 
partitionLeaderEpoch)
         } else
           throw e
     }
diff --git a/core/src/main/scala/kafka/log/UnifiedLog.scala b/core/src/main/scala/kafka/log/UnifiedLog.scala
index b3c447faac4..fbacbe0af1a 100644
--- a/core/src/main/scala/kafka/log/UnifiedLog.scala
+++ b/core/src/main/scala/kafka/log/UnifiedLog.scala
@@ -669,6 +669,7 @@ class UnifiedLog(@volatile var logStartOffset: Long,
    * Append this message set to the active segment of the local log, assigning 
offsets and Partition Leader Epochs
    *
    * @param records The records to append
+   * @param leaderEpoch the epoch of the replica appending
    * @param origin Declares the origin of the append which affects required 
validations
    * @param requestLocal request local instance
    * @throws KafkaStorageException If the append fails due to an I/O error.
@@ -699,14 +700,15 @@ class UnifiedLog(@volatile var logStartOffset: Long,
    * Append this message set to the active segment of the local log without 
assigning offsets or Partition Leader Epochs
    *
    * @param records The records to append
+   * @param leaderEpoch the epoch of the replica appending
    * @throws KafkaStorageException If the append fails due to an I/O error.
    * @return Information about the appended messages including the first and 
last offset.
    */
-  def appendAsFollower(records: MemoryRecords): LogAppendInfo = {
+  def appendAsFollower(records: MemoryRecords, leaderEpoch: Int): 
LogAppendInfo = {
     append(records,
       origin = AppendOrigin.REPLICATION,
       validateAndAssignOffsets = false,
-      leaderEpoch = -1,
+      leaderEpoch = leaderEpoch,
       requestLocal = None,
       verificationGuard = VerificationGuard.SENTINEL,
       // disable to check the validation of record size since the record is 
already accepted by leader.
@@ -1085,63 +1087,85 @@ class UnifiedLog(@volatile var logStartOffset: Long,
     var shallowOffsetOfMaxTimestamp = -1L
     var readFirstMessage = false
     var lastOffsetOfFirstBatch = -1L
+    var skipRemainingBatches = false
 
     records.batches.forEach { batch =>
       if (origin == AppendOrigin.RAFT_LEADER && batch.partitionLeaderEpoch != 
leaderEpoch) {
-        throw new InvalidRecordException("Append from Raft leader did not set 
the batch epoch correctly")
+        throw new InvalidRecordException(
+          s"Append from Raft leader did not set the batch epoch correctly, 
expected $leaderEpoch " +
+          s"but the batch has ${batch.partitionLeaderEpoch}"
+        )
       }
       // we only validate V2 and higher to avoid potential compatibility 
issues with older clients
-      if (batch.magic >= RecordBatch.MAGIC_VALUE_V2 && origin == 
AppendOrigin.CLIENT && batch.baseOffset != 0)
+      if (batch.magic >= RecordBatch.MAGIC_VALUE_V2 && origin == 
AppendOrigin.CLIENT && batch.baseOffset != 0) {
         throw new InvalidRecordException(s"The baseOffset of the record batch 
in the append to $topicPartition should " +
           s"be 0, but it is ${batch.baseOffset}")
-
-      // update the first offset if on the first message. For magic versions 
older than 2, we use the last offset
-      // to avoid the need to decompress the data (the last offset can be 
obtained directly from the wrapper message).
-      // For magic version 2, we can get the first offset directly from the 
batch header.
-      // When appending to the leader, we will update LogAppendInfo.baseOffset 
with the correct value. In the follower
-      // case, validation will be more lenient.
-      // Also indicate whether we have the accurate first offset or not
-      if (!readFirstMessage) {
-        if (batch.magic >= RecordBatch.MAGIC_VALUE_V2)
-          firstOffset = batch.baseOffset
-        lastOffsetOfFirstBatch = batch.lastOffset
-        readFirstMessage = true
       }
 
-      // check that offsets are monotonically increasing
-      if (lastOffset >= batch.lastOffset)
-        monotonic = false
-
-      // update the last offset seen
-      lastOffset = batch.lastOffset
-      lastLeaderEpoch = batch.partitionLeaderEpoch
-
-      // Check if the message sizes are valid.
-      val batchSize = batch.sizeInBytes
-      if (!ignoreRecordSize && batchSize > config.maxMessageSize) {
-        
brokerTopicStats.topicStats(topicPartition.topic).bytesRejectedRate.mark(records.sizeInBytes)
-        
brokerTopicStats.allTopicsStats.bytesRejectedRate.mark(records.sizeInBytes)
-        throw new RecordTooLargeException(s"The record batch size in the 
append to $topicPartition is $batchSize bytes " +
-          s"which exceeds the maximum configured value of 
${config.maxMessageSize}.")
-      }
+      /* During replication of uncommitted data it is possible for the remote 
replica to send record batches after it lost
+       * leadership. This can happen if sending FETCH responses is slow. There 
is a race between sending the FETCH
+       * response and the replica truncating and appending to the log. The 
replicating replica resolves this issue by only
+       * persisting up to the current leader epoch used in the fetch request. 
See KAFKA-18723 for more details.
+       */
+      skipRemainingBatches = skipRemainingBatches || 
hasHigherPartitionLeaderEpoch(batch, origin, leaderEpoch)
+      if (skipRemainingBatches) {
+        info(
+          s"Skipping batch $batch from an origin of $origin because its 
partition leader epoch " +
+          s"${batch.partitionLeaderEpoch} is higher than the replica's current 
leader epoch " +
+          s"$leaderEpoch"
+        )
+      } else {
+        // update the first offset if on the first message. For magic versions 
older than 2, we use the last offset
+        // to avoid the need to decompress the data (the last offset can be 
obtained directly from the wrapper message).
+        // For magic version 2, we can get the first offset directly from the 
batch header.
+        // When appending to the leader, we will update 
LogAppendInfo.baseOffset with the correct value. In the follower
+        // case, validation will be more lenient.
+        // Also indicate whether we have the accurate first offset or not
+        if (!readFirstMessage) {
+          if (batch.magic >= RecordBatch.MAGIC_VALUE_V2) {
+            firstOffset = batch.baseOffset
+          }
+          lastOffsetOfFirstBatch = batch.lastOffset
+          readFirstMessage = true
+        }
 
-      // check the validity of the message by checking CRC
-      if (!batch.isValid) {
-        brokerTopicStats.allTopicsStats.invalidMessageCrcRecordsPerSec.mark()
-        throw new CorruptRecordException(s"Record is corrupt (stored crc = 
${batch.checksum()}) in topic partition $topicPartition.")
-      }
+        // check that offsets are monotonically increasing
+        if (lastOffset >= batch.lastOffset) {
+          monotonic = false
+        }
 
-      if (batch.maxTimestamp > maxTimestamp) {
-        maxTimestamp = batch.maxTimestamp
-        shallowOffsetOfMaxTimestamp = lastOffset
-      }
+        // update the last offset seen
+        lastOffset = batch.lastOffset
+        lastLeaderEpoch = batch.partitionLeaderEpoch
+
+        // Check if the message sizes are valid.
+        val batchSize = batch.sizeInBytes
+        if (!ignoreRecordSize && batchSize > config.maxMessageSize) {
+          
brokerTopicStats.topicStats(topicPartition.topic).bytesRejectedRate.mark(records.sizeInBytes)
+          
brokerTopicStats.allTopicsStats.bytesRejectedRate.mark(records.sizeInBytes)
+          throw new RecordTooLargeException(s"The record batch size in the 
append to $topicPartition is $batchSize bytes " +
+            s"which exceeds the maximum configured value of 
${config.maxMessageSize}.")
+        }
 
-      validBytesCount += batchSize
+        // check the validity of the message by checking CRC
+        if (!batch.isValid) {
+          brokerTopicStats.allTopicsStats.invalidMessageCrcRecordsPerSec.mark()
+          throw new CorruptRecordException(s"Record is corrupt (stored crc = 
${batch.checksum()}) in topic partition $topicPartition.")
+        }
 
-      val batchCompression = CompressionType.forId(batch.compressionType.id)
-      // sourceCompression is only used on the leader path, which only 
contains one batch if version is v2 or messages are compressed
-      if (batchCompression != CompressionType.NONE)
-        sourceCompression = batchCompression
+        if (batch.maxTimestamp > maxTimestamp) {
+          maxTimestamp = batch.maxTimestamp
+          shallowOffsetOfMaxTimestamp = lastOffset
+        }
+
+        validBytesCount += batchSize
+
+        val batchCompression = CompressionType.forId(batch.compressionType.id)
+        // sourceCompression is only used on the leader path, which only 
contains one batch if version is v2 or messages are compressed
+        if (batchCompression != CompressionType.NONE) {
+          sourceCompression = batchCompression
+        }
+      }
     }
 
     if (requireOffsetsMonotonic && !monotonic)
@@ -1158,6 +1182,25 @@ class UnifiedLog(@volatile var logStartOffset: Long,
       validBytesCount, lastOffsetOfFirstBatch, 
Collections.emptyList[RecordError], LeaderHwChange.NONE)
   }
 
+  /**
+   * Return true if the record batch has a higher leader epoch than the 
specified leader epoch
+   *
+   * @param batch the batch to validate
+   * @param origin the reason for appending the record batch
+   * @param leaderEpoch the epoch to compare
+   * @return true if the append reason is replication and the batch's 
partition leader epoch is
+   *         greater than the specified leaderEpoch, otherwise false
+   */
+  private def hasHigherPartitionLeaderEpoch(
+    batch: RecordBatch,
+    origin: AppendOrigin,
+    leaderEpoch: Int
+  ): Boolean = {
+    origin == AppendOrigin.REPLICATION &&
+    batch.partitionLeaderEpoch() != RecordBatch.NO_PARTITION_LEADER_EPOCH &&
+    batch.partitionLeaderEpoch() > leaderEpoch
+  }
+
   /**
    * Trim any invalid bytes from the end of this message set (if there are any)
    *
@@ -1295,7 +1338,7 @@ class UnifiedLog(@volatile var logStartOffset: Long,
 
           val asyncOffsetReadFutureHolder = 
remoteOffsetReader.get.asyncOffsetRead(topicPartition, targetTimestamp,
             logStartOffset, leaderEpochCache, () => 
searchOffsetInLocalLog(targetTimestamp, localLogStartOffset()))
-          
+
           new OffsetResultHolder(Optional.empty(), 
Optional.of(asyncOffsetReadFutureHolder))
         } else {
           new OffsetResultHolder(searchOffsetInLocalLog(targetTimestamp, 
logStartOffset))
diff --git a/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala b/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala
index 3f6c2044df5..be03e0723af 100644
--- a/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala
+++ b/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala
@@ -25,6 +25,7 @@ import kafka.raft.KafkaMetadataLog.UnknownReason
 import kafka.utils.Logging
 import org.apache.kafka.common.config.TopicConfig
 import org.apache.kafka.common.errors.InvalidConfigurationException
+import org.apache.kafka.common.errors.CorruptRecordException
 import org.apache.kafka.common.record.{MemoryRecords, Records}
 import org.apache.kafka.common.utils.{Time, Utils}
 import org.apache.kafka.common.{KafkaException, TopicPartition, Uuid}
@@ -89,8 +90,9 @@ final class KafkaMetadataLog private (
   }
 
   override def appendAsLeader(records: Records, epoch: Int): LogAppendInfo = {
-    if (records.sizeInBytes == 0)
+    if (records.sizeInBytes == 0) {
       throw new IllegalArgumentException("Attempt to append an empty record 
set")
+    }
 
     handleAndConvertLogAppendInfo(
       log.appendAsLeader(records.asInstanceOf[MemoryRecords],
@@ -101,18 +103,20 @@ final class KafkaMetadataLog private (
     )
   }
 
-  override def appendAsFollower(records: Records): LogAppendInfo = {
-    if (records.sizeInBytes == 0)
+  override def appendAsFollower(records: Records, epoch: Int): LogAppendInfo = 
{
+    if (records.sizeInBytes == 0) {
       throw new IllegalArgumentException("Attempt to append an empty record 
set")
+    }
 
-    
handleAndConvertLogAppendInfo(log.appendAsFollower(records.asInstanceOf[MemoryRecords]))
+    
handleAndConvertLogAppendInfo(log.appendAsFollower(records.asInstanceOf[MemoryRecords],
 epoch))
   }
 
   private def handleAndConvertLogAppendInfo(appendInfo: 
internals.log.LogAppendInfo): LogAppendInfo = {
-    if (appendInfo.firstOffset != JUnifiedLog.UNKNOWN_OFFSET)
+    if (appendInfo.firstOffset == JUnifiedLog.UNKNOWN_OFFSET) {
+      throw new CorruptRecordException(s"Append failed unexpectedly 
$appendInfo")
+    } else {
       new LogAppendInfo(appendInfo.firstOffset, appendInfo.lastOffset)
-    else
-      throw new KafkaException(s"Append failed unexpectedly")
+    }
   }
 
   override def lastFetchedEpoch: Int = {
diff --git a/core/src/main/scala/kafka/server/AbstractFetcherThread.scala b/core/src/main/scala/kafka/server/AbstractFetcherThread.scala
index be663d19ec8..7a98c83e7f4 100755
--- a/core/src/main/scala/kafka/server/AbstractFetcherThread.scala
+++ b/core/src/main/scala/kafka/server/AbstractFetcherThread.scala
@@ -78,9 +78,12 @@ abstract class AbstractFetcherThread(name: String,
   /* callbacks to be defined in subclass */
 
   // process fetched data
-  protected def processPartitionData(topicPartition: TopicPartition,
-                                     fetchOffset: Long,
-                                     partitionData: FetchData): 
Option[LogAppendInfo]
+  protected def processPartitionData(
+    topicPartition: TopicPartition,
+    fetchOffset: Long,
+    partitionLeaderEpoch: Int,
+    partitionData: FetchData
+  ): Option[LogAppendInfo]
 
   protected def truncate(topicPartition: TopicPartition, truncationState: 
OffsetTruncationState): Unit
 
@@ -333,7 +336,9 @@ abstract class AbstractFetcherThread(name: String,
             // In this case, we only want to process the fetch response if the 
partition state is ready for fetch and
             // the current offset is the same as the offset requested.
             val fetchPartitionData = sessionPartitions.get(topicPartition)
-            if (fetchPartitionData != null && fetchPartitionData.fetchOffset 
== currentFetchState.fetchOffset && currentFetchState.isReadyForFetch) {
+            if (fetchPartitionData != null &&
+                fetchPartitionData.fetchOffset == 
currentFetchState.fetchOffset &&
+                currentFetchState.isReadyForFetch) {
               Errors.forCode(partitionData.errorCode) match {
                 case Errors.NONE =>
                   try {
@@ -348,10 +353,16 @@ abstract class AbstractFetcherThread(name: String,
                         .setLeaderEpoch(partitionData.divergingEpoch.epoch)
                         .setEndOffset(partitionData.divergingEpoch.endOffset)
                     } else {
-                      // Once we hand off the partition data to the subclass, 
we can't mess with it any more in this thread
+                      /* Once we hand off the partition data to the subclass, 
we can't mess with it any more in this thread
+                       *
+                       * When appending batches to the log only append record 
batches up to the leader epoch when the FETCH
+                       * request was handled. This is done to make sure that 
logs are not inconsistent because of log
+                       * truncation and append after the FETCH request was 
handled. See KAFKA-18723 for more details.
+                       */
                       val logAppendInfoOpt = processPartitionData(
                         topicPartition,
                         currentFetchState.fetchOffset,
+                        
fetchPartitionData.currentLeaderEpoch.orElse(currentFetchState.currentLeaderEpoch),
                         partitionData
                       )
 
diff --git a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala
index 56492de3485..5f5373b3641 100644
--- a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala
+++ b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala
@@ -66,9 +66,12 @@ class ReplicaAlterLogDirsThread(name: String,
   }
 
   // process fetched data
-  override def processPartitionData(topicPartition: TopicPartition,
-                                    fetchOffset: Long,
-                                    partitionData: FetchData): 
Option[LogAppendInfo] = {
+  override def processPartitionData(
+    topicPartition: TopicPartition,
+    fetchOffset: Long,
+    partitionLeaderEpoch: Int,
+    partitionData: FetchData
+  ): Option[LogAppendInfo] = {
     val partition = replicaMgr.getPartitionOrException(topicPartition)
     val futureLog = partition.futureLocalLogOrException
     val records = toMemoryRecords(FetchResponse.recordsOrFail(partitionData))
@@ -78,7 +81,7 @@ class ReplicaAlterLogDirsThread(name: String,
         topicPartition, fetchOffset, futureLog.logEndOffset))
 
     val logAppendInfo = if (records.sizeInBytes() > 0)
-      partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = 
true)
+      partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = 
true, partitionLeaderEpoch)
     else
       None
 
diff --git a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
index 7f0c6d41dbd..4c11301c567 100644
--- a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
+++ b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
@@ -98,9 +98,12 @@ class ReplicaFetcherThread(name: String,
   }
 
   // process fetched data
-  override def processPartitionData(topicPartition: TopicPartition,
-                                    fetchOffset: Long,
-                                    partitionData: FetchData): 
Option[LogAppendInfo] = {
+  override def processPartitionData(
+    topicPartition: TopicPartition,
+    fetchOffset: Long,
+    partitionLeaderEpoch: Int,
+    partitionData: FetchData
+  ): Option[LogAppendInfo] = {
     val logTrace = isTraceEnabled
     val partition = replicaMgr.getPartitionOrException(topicPartition)
     val log = partition.localLogOrException
@@ -117,7 +120,7 @@ class ReplicaFetcherThread(name: String,
         .format(log.logEndOffset, topicPartition, records.sizeInBytes, 
partitionData.highWatermark))
 
     // Append the leader's messages to the log
-    val logAppendInfo = 
partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = false)
+    val logAppendInfo = 
partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = false, 
partitionLeaderEpoch)
 
     if (logTrace)
       trace("Follower has replica log end offset %d after appending %d bytes 
of messages for partition %s"
diff --git a/core/src/test/scala/kafka/raft/KafkaMetadataLogTest.scala b/core/src/test/scala/kafka/raft/KafkaMetadataLogTest.scala
index 5f1752a54a6..263c35f5bfa 100644
--- a/core/src/test/scala/kafka/raft/KafkaMetadataLogTest.scala
+++ b/core/src/test/scala/kafka/raft/KafkaMetadataLogTest.scala
@@ -19,9 +19,12 @@ package kafka.raft
 import kafka.server.{KafkaConfig, KafkaRaftServer}
 import kafka.utils.TestUtils
 import org.apache.kafka.common.compress.Compression
+import org.apache.kafka.common.errors.CorruptRecordException
 import org.apache.kafka.common.errors.{InvalidConfigurationException, 
RecordTooLargeException}
 import org.apache.kafka.common.protocol
 import org.apache.kafka.common.protocol.{ObjectSerializationCache, Writable}
+import org.apache.kafka.common.record.ArbitraryMemoryRecords
+import org.apache.kafka.common.record.InvalidMemoryRecordsProvider
 import org.apache.kafka.common.record.{MemoryRecords, SimpleRecord}
 import org.apache.kafka.common.utils.Utils
 import org.apache.kafka.raft._
@@ -33,7 +36,14 @@ import org.apache.kafka.snapshot.{FileRawSnapshotWriter, 
RawSnapshotReader, RawS
 import org.apache.kafka.storage.internals.log.{LogConfig, 
LogStartOffsetIncrementReason, UnifiedLog}
 import org.apache.kafka.test.TestUtils.assertOptional
 import org.junit.jupiter.api.Assertions._
+import org.junit.jupiter.api.function.Executable
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
+import org.junit.jupiter.params.ParameterizedTest
+import org.junit.jupiter.params.provider.ArgumentsSource
+
+import net.jqwik.api.AfterFailureMode
+import net.jqwik.api.ForAll
+import net.jqwik.api.Property
 
 import java.io.File
 import java.nio.ByteBuffer
@@ -108,12 +118,93 @@ final class KafkaMetadataLogTest {
       classOf[RuntimeException],
       () => {
         log.appendAsFollower(
-          MemoryRecords.withRecords(initialOffset, Compression.NONE, 
currentEpoch, recordFoo)
+          MemoryRecords.withRecords(initialOffset, Compression.NONE, 
currentEpoch, recordFoo),
+          currentEpoch
         )
       }
     )
   }
 
+  @Test
+  def testEmptyAppendNotAllowed(): Unit = {
+    val log = buildMetadataLog(tempDir, mockTime)
+
+    assertThrows(classOf[IllegalArgumentException], () => 
log.appendAsFollower(MemoryRecords.EMPTY, 1));
+    assertThrows(classOf[IllegalArgumentException], () => 
log.appendAsLeader(MemoryRecords.EMPTY, 1));
+  }
+
+  @ParameterizedTest
+  @ArgumentsSource(classOf[InvalidMemoryRecordsProvider])
+  def testInvalidMemoryRecords(records: MemoryRecords, expectedException: 
Optional[Class[Exception]]): Unit = {
+    val log = buildMetadataLog(tempDir, mockTime)
+    val previousEndOffset = log.endOffset().offset()
+
+    val action: Executable = () => log.appendAsFollower(records, Int.MaxValue)
+    if (expectedException.isPresent()) {
+      assertThrows(expectedException.get, action)
+    } else {
+      assertThrows(classOf[CorruptRecordException], action)
+    }
+
+    assertEquals(previousEndOffset, log.endOffset().offset())
+  }
+
+  @Property(tries = 100, afterFailure = AfterFailureMode.SAMPLE_ONLY)
+  def testRandomRecords(
+    @ForAll(supplier = classOf[ArbitraryMemoryRecords]) records: MemoryRecords
+  ): Unit = {
+    val tempDir = TestUtils.tempDir()
+    try {
+      val log = buildMetadataLog(tempDir, mockTime)
+      val previousEndOffset = log.endOffset().offset()
+
+      assertThrows(
+        classOf[CorruptRecordException],
+        () => log.appendAsFollower(records, Int.MaxValue)
+      )
+
+      assertEquals(previousEndOffset, log.endOffset().offset())
+    } finally {
+      Utils.delete(tempDir)
+    }
+  }
+
+  @Test
+  def testInvalidLeaderEpoch(): Unit = {
+    val log = buildMetadataLog(tempDir, mockTime)
+    val previousEndOffset = log.endOffset().offset()
+    val epoch = log.lastFetchedEpoch() + 1
+    val numberOfRecords = 10
+
+    val batchWithValidEpoch = MemoryRecords.withRecords(
+      previousEndOffset,
+      Compression.NONE,
+      epoch,
+      (0 until numberOfRecords).map(number => new 
SimpleRecord(number.toString.getBytes)): _*
+    )
+
+    val batchWithInvalidEpoch = MemoryRecords.withRecords(
+      previousEndOffset + numberOfRecords,
+      Compression.NONE,
+      epoch + 1,
+      (0 until numberOfRecords).map(number => new 
SimpleRecord(number.toString.getBytes)): _*
+    )
+
+    val buffer = ByteBuffer.allocate(batchWithValidEpoch.sizeInBytes() + 
batchWithInvalidEpoch.sizeInBytes())
+    buffer.put(batchWithValidEpoch.buffer())
+    buffer.put(batchWithInvalidEpoch.buffer())
+    buffer.flip()
+
+    val records = MemoryRecords.readableRecords(buffer)
+
+    log.appendAsFollower(records, epoch)
+
+    // Check that only the first batch was appended
+    assertEquals(previousEndOffset + numberOfRecords, log.endOffset().offset())
+    // Check that the last fetched epoch matches the first batch
+    assertEquals(epoch, log.lastFetchedEpoch())
+  }
+
   @Test
   def testCreateSnapshot(): Unit = {
     val numberOfRecords = 10
@@ -1061,4 +1152,4 @@ object KafkaMetadataLogTest {
     }
     dir
   }
-}
\ No newline at end of file
+}
diff --git a/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala b/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
index 9aaaa4a64da..1bb32f1d6c2 100644
--- a/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
+++ b/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
@@ -428,6 +428,7 @@ class PartitionTest extends AbstractPartitionTest {
   def testMakeFollowerWithWithFollowerAppendRecords(): Unit = {
     val appendSemaphore = new Semaphore(0)
     val mockTime = new MockTime()
+    val prevLeaderEpoch = 0
 
     partition = new Partition(
       topicPartition,
@@ -480,24 +481,38 @@ class PartitionTest extends AbstractPartitionTest {
     }
 
     partition.createLogIfNotExists(isNew = true, isFutureReplica = false, 
offsetCheckpoints, None)
+    var partitionState = new LeaderAndIsrRequest.PartitionState()
+      .setControllerEpoch(0)
+      .setLeader(2)
+      .setLeaderEpoch(prevLeaderEpoch)
+      .setIsr(List[Integer](0, 1, 2, brokerId).asJava)
+      .setPartitionEpoch(1)
+      .setReplicas(List[Integer](0, 1, 2, brokerId).asJava)
+      .setIsNew(false)
+    assertTrue(partition.makeFollower(partitionState, offsetCheckpoints, None))
 
     val appendThread = new Thread {
       override def run(): Unit = {
-        val records = createRecords(List(new SimpleRecord("k1".getBytes, 
"v1".getBytes),
-          new SimpleRecord("k2".getBytes, "v2".getBytes)),
-          baseOffset = 0)
-        partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = 
false)
+        val records = createRecords(
+          List(
+            new SimpleRecord("k1".getBytes, "v1".getBytes),
+            new SimpleRecord("k2".getBytes, "v2".getBytes)
+          ),
+          baseOffset = 0,
+          partitionLeaderEpoch = prevLeaderEpoch
+        )
+        partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = 
false, prevLeaderEpoch)
       }
     }
     appendThread.start()
     TestUtils.waitUntilTrue(() => appendSemaphore.hasQueuedThreads, "follower 
log append is not called.")
 
-    val partitionState = new LeaderAndIsrRequest.PartitionState()
+    partitionState = new LeaderAndIsrRequest.PartitionState()
       .setControllerEpoch(0)
       .setLeader(2)
-      .setLeaderEpoch(1)
+      .setLeaderEpoch(prevLeaderEpoch + 1)
       .setIsr(List[Integer](0, 1, 2, brokerId).asJava)
-      .setPartitionEpoch(1)
+      .setPartitionEpoch(2)
       .setReplicas(List[Integer](0, 1, 2, brokerId).asJava)
       .setIsNew(false)
     assertTrue(partition.makeFollower(partitionState, offsetCheckpoints, None))
@@ -537,15 +552,22 @@ class PartitionTest extends AbstractPartitionTest {
     // Write to the future replica as if the log had been compacted, and do 
not roll the segment
 
     val buffer = ByteBuffer.allocate(1024)
-    val builder = MemoryRecords.builder(buffer, 
RecordBatch.CURRENT_MAGIC_VALUE, Compression.NONE,
-      TimestampType.CREATE_TIME, 0L, RecordBatch.NO_TIMESTAMP, 0)
+    val builder = MemoryRecords.builder(
+      buffer,
+      RecordBatch.CURRENT_MAGIC_VALUE,
+      Compression.NONE,
+      TimestampType.CREATE_TIME,
+      0L, // baseOffset
+      RecordBatch.NO_TIMESTAMP,
+      0 // partitionLeaderEpoch
+    )
     builder.appendWithOffset(2L, new SimpleRecord("k1".getBytes, 
"v3".getBytes))
     builder.appendWithOffset(5L, new SimpleRecord("k2".getBytes, 
"v6".getBytes))
     builder.appendWithOffset(6L, new SimpleRecord("k3".getBytes, 
"v7".getBytes))
     builder.appendWithOffset(7L, new SimpleRecord("k4".getBytes, 
"v8".getBytes))
 
     val futureLog = partition.futureLocalLogOrException
-    futureLog.appendAsFollower(builder.build())
+    futureLog.appendAsFollower(builder.build(), 0)
 
     assertTrue(partition.maybeReplaceCurrentWithFutureReplica())
   }
@@ -955,6 +977,18 @@ class PartitionTest extends AbstractPartitionTest {
   def testAppendRecordsAsFollowerBelowLogStartOffset(): Unit = {
     partition.createLogIfNotExists(isNew = false, isFutureReplica = false, 
offsetCheckpoints, None)
     val log = partition.localLogOrException
+    val epoch = 1
+
+    // Start off as follower
+    val partitionState = new LeaderAndIsrRequest.PartitionState()
+      .setControllerEpoch(0)
+      .setLeader(1)
+      .setLeaderEpoch(epoch)
+      .setIsr(List[Integer](0, 1, 2, brokerId).asJava)
+      .setPartitionEpoch(1)
+      .setReplicas(List[Integer](0, 1, 2, brokerId).asJava)
+      .setIsNew(false)
+    partition.makeFollower(partitionState, offsetCheckpoints, None)
 
     val initialLogStartOffset = 5L
     partition.truncateFullyAndStartAt(initialLogStartOffset, isFuture = false)
@@ -964,9 +998,14 @@ class PartitionTest extends AbstractPartitionTest {
       s"Log start offset after truncate fully and start at 
$initialLogStartOffset:")
 
     // verify that we cannot append records that do not contain log start 
offset even if the log is empty
-    assertThrows(classOf[UnexpectedAppendOffsetException], () =>
+    assertThrows(
+      classOf[UnexpectedAppendOffsetException],
       // append one record with offset = 3
-      partition.appendRecordsToFollowerOrFutureReplica(createRecords(List(new 
SimpleRecord("k1".getBytes, "v1".getBytes)), baseOffset = 3L), isFuture = false)
+      () => partition.appendRecordsToFollowerOrFutureReplica(
+        createRecords(List(new SimpleRecord("k1".getBytes, "v1".getBytes)), 
baseOffset = 3L),
+        isFuture = false,
+        partitionLeaderEpoch = epoch
+      )
     )
     assertEquals(initialLogStartOffset, log.logEndOffset,
       s"Log end offset should not change after failure to append")
@@ -978,12 +1017,16 @@ class PartitionTest extends AbstractPartitionTest {
                                      new SimpleRecord("k2".getBytes, 
"v2".getBytes),
                                      new SimpleRecord("k3".getBytes, 
"v3".getBytes)),
                                 baseOffset = newLogStartOffset)
-    partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = false)
+    partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = 
false, partitionLeaderEpoch = epoch)
     assertEquals(7L, log.logEndOffset, s"Log end offset after append of 3 
records with base offset $newLogStartOffset:")
     assertEquals(newLogStartOffset, log.logStartOffset, s"Log start offset 
after append of 3 records with base offset $newLogStartOffset:")
 
     // and we can append more records after that
-    partition.appendRecordsToFollowerOrFutureReplica(createRecords(List(new 
SimpleRecord("k1".getBytes, "v1".getBytes)), baseOffset = 7L), isFuture = false)
+    partition.appendRecordsToFollowerOrFutureReplica(
+      createRecords(List(new SimpleRecord("k1".getBytes, "v1".getBytes)), 
baseOffset = 7L),
+      isFuture = false,
+      partitionLeaderEpoch = epoch
+    )
     assertEquals(8L, log.logEndOffset, s"Log end offset after append of 1 
record at offset 7:")
     assertEquals(newLogStartOffset, log.logStartOffset, s"Log start offset not 
expected to change:")
 
@@ -991,11 +1034,18 @@ class PartitionTest extends AbstractPartitionTest {
     val records2 = createRecords(List(new SimpleRecord("k1".getBytes, 
"v1".getBytes),
       new SimpleRecord("k2".getBytes, "v2".getBytes)),
       baseOffset = 3L)
-    assertThrows(classOf[UnexpectedAppendOffsetException], () => 
partition.appendRecordsToFollowerOrFutureReplica(records2, isFuture = false))
+    assertThrows(
+      classOf[UnexpectedAppendOffsetException],
+      () => partition.appendRecordsToFollowerOrFutureReplica(records2, 
isFuture = false, partitionLeaderEpoch = epoch)
+    )
     assertEquals(8L, log.logEndOffset, s"Log end offset should not change 
after failure to append")
 
     // we still can append to next offset
-    partition.appendRecordsToFollowerOrFutureReplica(createRecords(List(new 
SimpleRecord("k1".getBytes, "v1".getBytes)), baseOffset = 8L), isFuture = false)
+    partition.appendRecordsToFollowerOrFutureReplica(
+      createRecords(List(new SimpleRecord("k1".getBytes, "v1".getBytes)), 
baseOffset = 8L),
+      isFuture = false,
+      partitionLeaderEpoch = epoch
+    )
     assertEquals(9L, log.logEndOffset, s"Log end offset after append of 1 
record at offset 8:")
     assertEquals(newLogStartOffset, log.logStartOffset, s"Log start offset not 
expected to change:")
   }
@@ -1078,9 +1128,13 @@ class PartitionTest extends AbstractPartitionTest {
 
   @Test
   def testAppendRecordsToFollowerWithNoReplicaThrowsException(): Unit = {
-    assertThrows(classOf[NotLeaderOrFollowerException], () =>
-      partition.appendRecordsToFollowerOrFutureReplica(
-           createRecords(List(new SimpleRecord("k1".getBytes, "v1".getBytes)), 
baseOffset = 0L), isFuture = false)
+    assertThrows(
+      classOf[NotLeaderOrFollowerException],
+      () => partition.appendRecordsToFollowerOrFutureReplica(
+        createRecords(List(new SimpleRecord("k1".getBytes, "v1".getBytes)), 
baseOffset = 0L),
+        isFuture = false,
+        partitionLeaderEpoch = 0
+      )
     )
   }
 
@@ -3457,12 +3511,13 @@ class PartitionTest extends AbstractPartitionTest {
 
     val replicas = Seq(brokerId, brokerId + 1)
     val isr = replicas
+    val epoch = 0
     addBrokerEpochToMockMetadataCache(metadataCache, replicas.toList)
     partition.makeLeader(
       new LeaderAndIsrRequest.PartitionState()
         .setControllerEpoch(0)
         .setLeader(brokerId)
-        .setLeaderEpoch(0)
+        .setLeaderEpoch(epoch)
         .setIsr(isr.map(Int.box).asJava)
         .setReplicas(replicas.map(Int.box).asJava)
         .setPartitionEpoch(1)
@@ -3495,7 +3550,8 @@ class PartitionTest extends AbstractPartitionTest {
 
     partition.appendRecordsToFollowerOrFutureReplica(
       records = records,
-      isFuture = true
+      isFuture = true,
+      partitionLeaderEpoch = epoch
     )
 
     listener.verify()
@@ -3640,9 +3696,9 @@ class PartitionTest extends AbstractPartitionTest {
     producerStateManager,
     _topicId = topicId) {
 
-    override def appendAsFollower(records: MemoryRecords): LogAppendInfo = {
+    override def appendAsFollower(records: MemoryRecords, epoch: Int): 
LogAppendInfo = {
       appendSemaphore.acquire()
-      val appendInfo = super.appendAsFollower(records)
+      val appendInfo = super.appendAsFollower(records, epoch)
       appendInfo
     }
   }
diff --git a/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala b/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
index 895d5d64363..716731b48d3 100644
--- a/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
@@ -1457,7 +1457,7 @@ class LogCleanerTest extends Logging {
       log.appendAsLeader(TestUtils.singletonRecords(value = v, key = k), 
leaderEpoch = 0)
       //0 to Int.MaxValue is Int.MaxValue+1 message, -1 will be the last 
message of i-th segment
       val records = messageWithOffset(k, v, (i + 1L) * (Int.MaxValue + 1L) -1 )
-      log.appendAsFollower(records)
+      log.appendAsFollower(records, Int.MaxValue)
       assertEquals(i + 1, log.numberOfSegments)
     }
 
@@ -1511,7 +1511,7 @@ class LogCleanerTest extends Logging {
 
     // forward offset and append message to next segment at offset Int.MaxValue
     val records = messageWithOffset("hello".getBytes, "hello".getBytes, 
Int.MaxValue - 1)
-    log.appendAsFollower(records)
+    log.appendAsFollower(records, Int.MaxValue)
     log.appendAsLeader(TestUtils.singletonRecords(value = "hello".getBytes, 
key = "hello".getBytes), leaderEpoch = 0)
     assertEquals(Int.MaxValue, log.activeSegment.offsetIndex.lastOffset)
 
@@ -1560,14 +1560,14 @@ class LogCleanerTest extends Logging {
     val log = makeLog(config = LogConfig.fromProps(logConfig.originals, 
logProps))
 
     val record1 = messageWithOffset("hello".getBytes, "hello".getBytes, 0)
-    log.appendAsFollower(record1)
+    log.appendAsFollower(record1, Int.MaxValue)
     val record2 = messageWithOffset("hello".getBytes, "hello".getBytes, 1)
-    log.appendAsFollower(record2)
+    log.appendAsFollower(record2, Int.MaxValue)
     log.roll(Some(Int.MaxValue/2)) // starting a new log segment at offset 
Int.MaxValue/2
     val record3 = messageWithOffset("hello".getBytes, "hello".getBytes, 
Int.MaxValue/2)
-    log.appendAsFollower(record3)
+    log.appendAsFollower(record3, Int.MaxValue)
     val record4 = messageWithOffset("hello".getBytes, "hello".getBytes, 
Int.MaxValue.toLong + 1)
-    log.appendAsFollower(record4)
+    log.appendAsFollower(record4, Int.MaxValue)
 
     assertTrue(log.logEndOffset - 1 - log.logStartOffset > Int.MaxValue, 
"Actual offset range should be > Int.MaxValue")
     assertTrue(log.logSegments.asScala.last.offsetIndex.lastOffset - 
log.logStartOffset <= Int.MaxValue,
@@ -1881,8 +1881,8 @@ class LogCleanerTest extends Logging {
     val noDupSetOffset = 50
     val noDupSet = noDupSetKeys zip (noDupSetOffset until noDupSetOffset + 
noDupSetKeys.size)
 
-    log.appendAsFollower(invalidCleanedMessage(dupSetOffset, dupSet, codec))
-    log.appendAsFollower(invalidCleanedMessage(noDupSetOffset, noDupSet, 
codec))
+    log.appendAsFollower(invalidCleanedMessage(dupSetOffset, dupSet, codec), 
Int.MaxValue)
+    log.appendAsFollower(invalidCleanedMessage(noDupSetOffset, noDupSet, 
codec), Int.MaxValue)
 
     log.roll()
 
@@ -1968,7 +1968,7 @@ class LogCleanerTest extends Logging {
       log.roll(Some(11L))
 
       // active segment record
-      log.appendAsFollower(messageWithOffset(1015, 1015, 11L))
+      log.appendAsFollower(messageWithOffset(1015, 1015, 11L), Int.MaxValue)
 
       val (nextDirtyOffset, _) = cleaner.clean(LogToClean(log.topicPartition, 
log, 0L, log.activeSegment.baseOffset, needCompactionNow = true))
       assertEquals(log.activeSegment.baseOffset, nextDirtyOffset,
@@ -1987,7 +1987,7 @@ class LogCleanerTest extends Logging {
       log.roll(Some(30L))
 
       // active segment record
-      log.appendAsFollower(messageWithOffset(1015, 1015, 30L))
+      log.appendAsFollower(messageWithOffset(1015, 1015, 30L), Int.MaxValue)
 
       val (nextDirtyOffset, _) = cleaner.clean(LogToClean(log.topicPartition, 
log, 0L, log.activeSegment.baseOffset, needCompactionNow = true))
       assertEquals(log.activeSegment.baseOffset, nextDirtyOffset,
@@ -2204,7 +2204,7 @@ class LogCleanerTest extends Logging {
 
   private def writeToLog(log: UnifiedLog, keysAndValues: Iterable[(Int, Int)], 
offsetSeq: Iterable[Long]): Iterable[Long] = {
     for (((key, value), offset) <- keysAndValues.zip(offsetSeq))
-      yield log.appendAsFollower(messageWithOffset(key, value, 
offset)).lastOffset
+      yield log.appendAsFollower(messageWithOffset(key, value, offset), 
Int.MaxValue).lastOffset
   }
 
   private def invalidCleanedMessage(initialOffset: Long,
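
The test updates above all follow one mechanical pattern: UnifiedLog.appendAsFollower now takes the leader epoch of the fetching replica, and any batch whose partition leader epoch is greater than that value is not appended. Tests that do not exercise the filter pass Int.MaxValue so every batch is accepted; tests that do care pass the epoch their batches carry. A minimal sketch of the two call shapes, reusing the log and messageWithOffset fixtures from the hunks above (the epoch value 5 is a placeholder, not taken from the patch):

    // Epoch filtering is irrelevant to the assertion: disable it.
    log.appendAsFollower(
      messageWithOffset("k".getBytes, "v".getBytes, 0L),
      Int.MaxValue
    )

    // Epoch filtering is the behavior under test: pass the replica's epoch.
    val leaderEpoch = 5
    log.appendAsFollower(
      messageWithOffset("k".getBytes, "v".getBytes, 1L),
      leaderEpoch
    )
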
diff --git a/core/src/test/scala/unit/kafka/log/LogConcurrencyTest.scala 
b/core/src/test/scala/unit/kafka/log/LogConcurrencyTest.scala
index 342ef145b6d..72ad7a718d1 100644
--- a/core/src/test/scala/unit/kafka/log/LogConcurrencyTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogConcurrencyTest.scala
@@ -126,9 +126,14 @@ class LogConcurrencyTest {
               log.appendAsLeader(TestUtils.records(records), leaderEpoch)
               log.maybeIncrementHighWatermark(logEndOffsetMetadata)
             } else {
-              log.appendAsFollower(TestUtils.records(records,
-                baseOffset = logEndOffset,
-                partitionLeaderEpoch = leaderEpoch))
+              log.appendAsFollower(
+                TestUtils.records(
+                  records,
+                  baseOffset = logEndOffset,
+                  partitionLeaderEpoch = leaderEpoch
+                ),
+                Int.MaxValue
+              )
               log.updateHighWatermark(logEndOffset)
             }
 
diff --git a/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala 
b/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala
index 3bdf8a9436c..11c2b620058 100644
--- a/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala
@@ -923,17 +923,17 @@ class LogLoaderTest {
     val set3 = MemoryRecords.withRecords(Integer.MAX_VALUE.toLong + 3, 
Compression.NONE, 0, new SimpleRecord("v4".getBytes(), "k4".getBytes()))
     val set4 = MemoryRecords.withRecords(Integer.MAX_VALUE.toLong + 4, 
Compression.NONE, 0, new SimpleRecord("v5".getBytes(), "k5".getBytes()))
     //Writes into an empty log with baseOffset 0
-    log.appendAsFollower(set1)
+    log.appendAsFollower(set1, Int.MaxValue)
     assertEquals(0L, log.activeSegment.baseOffset)
     //This write will roll the segment, yielding a new segment with base 
offset = max(1, Integer.MAX_VALUE+2) = Integer.MAX_VALUE+2
-    log.appendAsFollower(set2)
+    log.appendAsFollower(set2, Int.MaxValue)
     assertEquals(Integer.MAX_VALUE.toLong + 2, log.activeSegment.baseOffset)
     assertTrue(LogFileUtils.producerSnapshotFile(logDir, 
Integer.MAX_VALUE.toLong + 2).exists)
     //This will go into the existing log
-    log.appendAsFollower(set3)
+    log.appendAsFollower(set3, Int.MaxValue)
     assertEquals(Integer.MAX_VALUE.toLong + 2, log.activeSegment.baseOffset)
     //This will go into the existing log
-    log.appendAsFollower(set4)
+    log.appendAsFollower(set4, Int.MaxValue)
     assertEquals(Integer.MAX_VALUE.toLong + 2, log.activeSegment.baseOffset)
     log.close()
     val indexFiles = logDir.listFiles.filter(file => 
file.getName.contains(".index"))
@@ -962,17 +962,17 @@ class LogLoaderTest {
       new SimpleRecord("v7".getBytes(), "k7".getBytes()),
       new SimpleRecord("v8".getBytes(), "k8".getBytes()))
     //Writes into an empty log with baseOffset 0
-    log.appendAsFollower(set1)
+    log.appendAsFollower(set1, Int.MaxValue)
     assertEquals(0L, log.activeSegment.baseOffset)
     //This write will roll the segment, yielding a new segment with base 
offset = max(1, Integer.MAX_VALUE+2) = Integer.MAX_VALUE+2
-    log.appendAsFollower(set2)
+    log.appendAsFollower(set2, Int.MaxValue)
     assertEquals(Integer.MAX_VALUE.toLong + 2, log.activeSegment.baseOffset)
     assertTrue(LogFileUtils.producerSnapshotFile(logDir, 
Integer.MAX_VALUE.toLong + 2).exists)
     //This will go into the existing log
-    log.appendAsFollower(set3)
+    log.appendAsFollower(set3, Int.MaxValue)
     assertEquals(Integer.MAX_VALUE.toLong + 2, log.activeSegment.baseOffset)
     //This will go into the existing log
-    log.appendAsFollower(set4)
+    log.appendAsFollower(set4, Int.MaxValue)
     assertEquals(Integer.MAX_VALUE.toLong + 2, log.activeSegment.baseOffset)
     log.close()
     val indexFiles = logDir.listFiles.filter(file => 
file.getName.contains(".index"))
@@ -1002,18 +1002,18 @@ class LogLoaderTest {
       new SimpleRecord("v7".getBytes(), "k7".getBytes()),
       new SimpleRecord("v8".getBytes(), "k8".getBytes()))
     //Writes into an empty log with baseOffset 0
-    log.appendAsFollower(set1)
+    log.appendAsFollower(set1, Int.MaxValue)
     assertEquals(0L, log.activeSegment.baseOffset)
     //This write will roll the segment, yielding a new segment with base 
offset = max(1, 3) = 3
-    log.appendAsFollower(set2)
+    log.appendAsFollower(set2, Int.MaxValue)
     assertEquals(3, log.activeSegment.baseOffset)
     assertTrue(LogFileUtils.producerSnapshotFile(logDir, 3).exists)
     //This will also roll the segment, yielding a new segment with base offset 
= max(5, Integer.MAX_VALUE+4) = Integer.MAX_VALUE+4
-    log.appendAsFollower(set3)
+    log.appendAsFollower(set3, Int.MaxValue)
     assertEquals(Integer.MAX_VALUE.toLong + 4, log.activeSegment.baseOffset)
     assertTrue(LogFileUtils.producerSnapshotFile(logDir, 
Integer.MAX_VALUE.toLong + 4).exists)
     //This will go into the existing log
-    log.appendAsFollower(set4)
+    log.appendAsFollower(set4, Int.MaxValue)
     assertEquals(Integer.MAX_VALUE.toLong + 4, log.activeSegment.baseOffset)
     log.close()
     val indexFiles = logDir.listFiles.filter(file => 
file.getName.contains(".index"))
@@ -1203,16 +1203,16 @@ class LogLoaderTest {
     val log = createLog(logDir, new LogConfig(new Properties))
     val leaderEpochCache = log.leaderEpochCache
     val firstBatch = singletonRecordsWithLeaderEpoch(value = 
"random".getBytes, leaderEpoch = 1, offset = 0)
-    log.appendAsFollower(records = firstBatch)
+    log.appendAsFollower(records = firstBatch, Int.MaxValue)
 
     val secondBatch = singletonRecordsWithLeaderEpoch(value = 
"random".getBytes, leaderEpoch = 2, offset = 1)
-    log.appendAsFollower(records = secondBatch)
+    log.appendAsFollower(records = secondBatch, Int.MaxValue)
 
     val thirdBatch = singletonRecordsWithLeaderEpoch(value = 
"random".getBytes, leaderEpoch = 2, offset = 2)
-    log.appendAsFollower(records = thirdBatch)
+    log.appendAsFollower(records = thirdBatch, Int.MaxValue)
 
     val fourthBatch = singletonRecordsWithLeaderEpoch(value = 
"random".getBytes, leaderEpoch = 3, offset = 3)
-    log.appendAsFollower(records = fourthBatch)
+    log.appendAsFollower(records = fourthBatch, Int.MaxValue)
 
     assertEquals(java.util.Arrays.asList(new EpochEntry(1, 0), new 
EpochEntry(2, 1), new EpochEntry(3, 3)), leaderEpochCache.epochEntries)
 
diff --git a/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala 
b/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala
index 2ac71abd7be..f1e53014d72 100755
--- a/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala
+++ b/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala
@@ -48,11 +48,16 @@ import 
org.apache.kafka.storage.log.metrics.{BrokerTopicMetrics, BrokerTopicStat
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
 import org.junit.jupiter.params.ParameterizedTest
+import org.junit.jupiter.params.provider.ArgumentsSource
 import org.junit.jupiter.params.provider.{EnumSource, ValueSource}
 import org.mockito.ArgumentMatchers
 import org.mockito.ArgumentMatchers.{any, anyLong}
 import org.mockito.Mockito.{doAnswer, doThrow, spy}
 
+import net.jqwik.api.AfterFailureMode
+import net.jqwik.api.ForAll
+import net.jqwik.api.Property
+
 import java.io._
 import java.nio.ByteBuffer
 import java.nio.file.Files
@@ -304,7 +309,7 @@ class UnifiedLogTest {
     assertHighWatermark(3L)
 
     // Update high watermark as follower
-    log.appendAsFollower(records(3L))
+    log.appendAsFollower(records(3L), leaderEpoch)
     log.updateHighWatermark(6L)
     assertHighWatermark(6L)
 
@@ -582,6 +587,7 @@ class UnifiedLogTest {
   @Test
   def testRollSegmentThatAlreadyExists(): Unit = {
     val logConfig = LogTestUtils.createLogConfig(segmentMs = 1 * 60 * 60L)
+    val partitionLeaderEpoch = 0
 
     // create a log
     val log = createLog(logDir, logConfig)
@@ -594,16 +600,16 @@ class UnifiedLogTest {
     // should be able to append records to active segment
     val records = TestUtils.records(
       List(new SimpleRecord(mockTime.milliseconds, "k1".getBytes, 
"v1".getBytes)),
-      baseOffset = 0L, partitionLeaderEpoch = 0)
-    log.appendAsFollower(records)
+      baseOffset = 0L, partitionLeaderEpoch = partitionLeaderEpoch)
+    log.appendAsFollower(records, partitionLeaderEpoch)
     assertEquals(1, log.numberOfSegments, "Expect one segment.")
     assertEquals(0L, log.activeSegment.baseOffset)
 
     // make sure we can append more records
     val records2 = TestUtils.records(
       List(new SimpleRecord(mockTime.milliseconds + 10, "k2".getBytes, 
"v2".getBytes)),
-      baseOffset = 1L, partitionLeaderEpoch = 0)
-    log.appendAsFollower(records2)
+      baseOffset = 1L, partitionLeaderEpoch = partitionLeaderEpoch)
+    log.appendAsFollower(records2, partitionLeaderEpoch)
 
     assertEquals(2, log.logEndOffset, "Expect two records in the log")
     assertEquals(0, LogTestUtils.readLog(log, 0, 
1).records.batches.iterator.next().lastOffset)
@@ -618,8 +624,8 @@ class UnifiedLogTest {
     log.activeSegment.offsetIndex.resize(0)
     val records3 = TestUtils.records(
       List(new SimpleRecord(mockTime.milliseconds + 12, "k3".getBytes, 
"v3".getBytes)),
-      baseOffset = 2L, partitionLeaderEpoch = 0)
-    log.appendAsFollower(records3)
+      baseOffset = 2L, partitionLeaderEpoch = partitionLeaderEpoch)
+    log.appendAsFollower(records3, partitionLeaderEpoch)
     assertTrue(log.activeSegment.offsetIndex.maxEntries > 1)
     assertEquals(2, LogTestUtils.readLog(log, 2, 
1).records.batches.iterator.next().lastOffset)
     assertEquals(2, log.numberOfSegments, "Expect two segments.")
@@ -793,17 +799,25 @@ class UnifiedLogTest {
     val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5)
     val log = createLog(logDir, logConfig)
     val pid = 1L
-    val epoch = 0.toShort
+    val producerEpoch = 0.toShort
+    val partitionLeaderEpoch = 0
     val seq = 0
     val baseOffset = 23L
 
     // create a batch with a couple gaps to simulate compaction
-    val records = TestUtils.records(producerId = pid, producerEpoch = epoch, 
sequence = seq, baseOffset = baseOffset, records = List(
-      new SimpleRecord(mockTime.milliseconds(), "a".getBytes),
-      new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "b".getBytes),
-      new SimpleRecord(mockTime.milliseconds(), "c".getBytes),
-      new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "d".getBytes)))
-    records.batches.forEach(_.setPartitionLeaderEpoch(0))
+    val records = TestUtils.records(
+      producerId = pid,
+      producerEpoch = producerEpoch,
+      sequence = seq,
+      baseOffset = baseOffset,
+      records = List(
+        new SimpleRecord(mockTime.milliseconds(), "a".getBytes),
+        new SimpleRecord(mockTime.milliseconds(), "key".getBytes, 
"b".getBytes),
+        new SimpleRecord(mockTime.milliseconds(), "c".getBytes),
+        new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "d".getBytes)
+      )
+    )
+    records.batches.forEach(_.setPartitionLeaderEpoch(partitionLeaderEpoch))
 
     val filtered = ByteBuffer.allocate(2048)
     records.filterTo(new RecordFilter(0, 0) {
@@ -814,14 +828,18 @@ class UnifiedLogTest {
     filtered.flip()
     val filteredRecords = MemoryRecords.readableRecords(filtered)
 
-    log.appendAsFollower(filteredRecords)
+    log.appendAsFollower(filteredRecords, partitionLeaderEpoch)
 
     // append some more data and then truncate to force rebuilding of the PID 
map
-    val moreRecords = TestUtils.records(baseOffset = baseOffset + 4, records = 
List(
-      new SimpleRecord(mockTime.milliseconds(), "e".getBytes),
-      new SimpleRecord(mockTime.milliseconds(), "f".getBytes)))
-    moreRecords.batches.forEach(_.setPartitionLeaderEpoch(0))
-    log.appendAsFollower(moreRecords)
+    val moreRecords = TestUtils.records(
+      baseOffset = baseOffset + 4,
+      records = List(
+        new SimpleRecord(mockTime.milliseconds(), "e".getBytes),
+        new SimpleRecord(mockTime.milliseconds(), "f".getBytes)
+      )
+    )
+    
moreRecords.batches.forEach(_.setPartitionLeaderEpoch(partitionLeaderEpoch))
+    log.appendAsFollower(moreRecords, partitionLeaderEpoch)
 
     log.truncateTo(baseOffset + 4)
 
@@ -837,15 +855,23 @@ class UnifiedLogTest {
     val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5)
     val log = createLog(logDir, logConfig)
     val pid = 1L
-    val epoch = 0.toShort
+    val producerEpoch = 0.toShort
+    val partitionLeaderEpoch = 0
     val seq = 0
     val baseOffset = 23L
 
     // create an empty batch
-    val records = TestUtils.records(producerId = pid, producerEpoch = epoch, 
sequence = seq, baseOffset = baseOffset, records = List(
-      new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "a".getBytes),
-      new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "b".getBytes)))
-    records.batches.forEach(_.setPartitionLeaderEpoch(0))
+    val records = TestUtils.records(
+      producerId = pid,
+      producerEpoch = producerEpoch,
+      sequence = seq,
+      baseOffset = baseOffset,
+      records = List(
+        new SimpleRecord(mockTime.milliseconds(), "key".getBytes, 
"a".getBytes),
+        new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "b".getBytes)
+      )
+    )
+    records.batches.forEach(_.setPartitionLeaderEpoch(partitionLeaderEpoch))
 
     val filtered = ByteBuffer.allocate(2048)
     records.filterTo(new RecordFilter(0, 0) {
@@ -856,14 +882,18 @@ class UnifiedLogTest {
     filtered.flip()
     val filteredRecords = MemoryRecords.readableRecords(filtered)
 
-    log.appendAsFollower(filteredRecords)
+    log.appendAsFollower(filteredRecords, partitionLeaderEpoch)
 
     // append some more data and then truncate to force rebuilding of the PID 
map
-    val moreRecords = TestUtils.records(baseOffset = baseOffset + 2, records = 
List(
-      new SimpleRecord(mockTime.milliseconds(), "e".getBytes),
-      new SimpleRecord(mockTime.milliseconds(), "f".getBytes)))
-    moreRecords.batches.forEach(_.setPartitionLeaderEpoch(0))
-    log.appendAsFollower(moreRecords)
+    val moreRecords = TestUtils.records(
+      baseOffset = baseOffset + 2,
+      records = List(
+        new SimpleRecord(mockTime.milliseconds(), "e".getBytes),
+        new SimpleRecord(mockTime.milliseconds(), "f".getBytes)
+      )
+    )
+    
moreRecords.batches.forEach(_.setPartitionLeaderEpoch(partitionLeaderEpoch))
+    log.appendAsFollower(moreRecords, partitionLeaderEpoch)
 
     log.truncateTo(baseOffset + 2)
 
@@ -879,17 +909,25 @@ class UnifiedLogTest {
     val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5)
     val log = createLog(logDir, logConfig)
     val pid = 1L
-    val epoch = 0.toShort
+    val producerEpoch = 0.toShort
+    val partitionLeaderEpoch = 0
     val seq = 0
     val baseOffset = 23L
 
     // create a batch with a couple gaps to simulate compaction
-    val records = TestUtils.records(producerId = pid, producerEpoch = epoch, 
sequence = seq, baseOffset = baseOffset, records = List(
-      new SimpleRecord(mockTime.milliseconds(), "a".getBytes),
-      new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "b".getBytes),
-      new SimpleRecord(mockTime.milliseconds(), "c".getBytes),
-      new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "d".getBytes)))
-    records.batches.forEach(_.setPartitionLeaderEpoch(0))
+    val records = TestUtils.records(
+      producerId = pid,
+      producerEpoch = producerEpoch,
+      sequence = seq,
+      baseOffset = baseOffset,
+      records = List(
+        new SimpleRecord(mockTime.milliseconds(), "a".getBytes),
+        new SimpleRecord(mockTime.milliseconds(), "key".getBytes, 
"b".getBytes),
+        new SimpleRecord(mockTime.milliseconds(), "c".getBytes),
+        new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "d".getBytes)
+      )
+    )
+    records.batches.forEach(_.setPartitionLeaderEpoch(partitionLeaderEpoch))
 
     val filtered = ByteBuffer.allocate(2048)
     records.filterTo(new RecordFilter(0, 0) {
@@ -900,7 +938,7 @@ class UnifiedLogTest {
     filtered.flip()
     val filteredRecords = MemoryRecords.readableRecords(filtered)
 
-    log.appendAsFollower(filteredRecords)
+    log.appendAsFollower(filteredRecords, partitionLeaderEpoch)
     val activeProducers = log.activeProducersWithLastSequence
     assertTrue(activeProducers.contains(pid))
 
@@ -1330,33 +1368,44 @@ class UnifiedLogTest {
     // create a log
     val log = createLog(logDir, new LogConfig(new Properties))
 
-    val epoch: Short = 0
+    val producerEpoch: Short = 0
+    val partitionLeaderEpoch = 0
     val buffer = ByteBuffer.allocate(512)
 
-    var builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, 
Compression.NONE,
-      TimestampType.LOG_APPEND_TIME, 0L, mockTime.milliseconds(), 1L, epoch, 
0, false, 0)
+    var builder = MemoryRecords.builder(
+      buffer, RecordBatch.MAGIC_VALUE_V2, Compression.NONE,
+      TimestampType.LOG_APPEND_TIME, 0L, mockTime.milliseconds(), 1L, 
producerEpoch, 0, false,
+      partitionLeaderEpoch
+    )
     builder.append(new SimpleRecord("key".getBytes, "value".getBytes))
     builder.close()
 
     builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, 
Compression.NONE,
-      TimestampType.LOG_APPEND_TIME, 1L, mockTime.milliseconds(), 2L, epoch, 
0, false, 0)
+      TimestampType.LOG_APPEND_TIME, 1L, mockTime.milliseconds(), 2L, 
producerEpoch, 0, false,
+      partitionLeaderEpoch)
     builder.append(new SimpleRecord("key".getBytes, "value".getBytes))
     builder.close()
 
-    builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, 
Compression.NONE,
-      TimestampType.LOG_APPEND_TIME, 2L, mockTime.milliseconds(), 3L, epoch, 
0, false, 0)
+    builder = MemoryRecords.builder(
+      buffer, RecordBatch.MAGIC_VALUE_V2, Compression.NONE,
+      TimestampType.LOG_APPEND_TIME, 2L, mockTime.milliseconds(), 3L, 
producerEpoch, 0, false,
+      partitionLeaderEpoch
+    )
     builder.append(new SimpleRecord("key".getBytes, "value".getBytes))
     builder.close()
 
-    builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, 
Compression.NONE,
-      TimestampType.LOG_APPEND_TIME, 3L, mockTime.milliseconds(), 4L, epoch, 
0, false, 0)
+    builder = MemoryRecords.builder(
+      buffer, RecordBatch.MAGIC_VALUE_V2, Compression.NONE,
+      TimestampType.LOG_APPEND_TIME, 3L, mockTime.milliseconds(), 4L, 
producerEpoch, 0, false,
+      partitionLeaderEpoch
+    )
     builder.append(new SimpleRecord("key".getBytes, "value".getBytes))
     builder.close()
 
     buffer.flip()
     val memoryRecords = MemoryRecords.readableRecords(buffer)
 
-    log.appendAsFollower(memoryRecords)
+    log.appendAsFollower(memoryRecords, partitionLeaderEpoch)
     log.flush(false)
 
     val fetchedData = LogTestUtils.readLog(log, 0, Int.MaxValue)
@@ -1375,7 +1424,7 @@ class UnifiedLogTest {
   def testDuplicateAppendToFollower(): Unit = {
     val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 1024 * 
5)
     val log = createLog(logDir, logConfig)
-    val epoch: Short = 0
+    val producerEpoch: Short = 0
     val pid = 1L
     val baseSequence = 0
     val partitionLeaderEpoch = 0
@@ -1383,10 +1432,32 @@ class UnifiedLogTest {
     // this is a bit contrived. to trigger the duplicate case for a follower 
append, we have to append
     // a batch with matching sequence numbers, but valid increasing offsets
     assertEquals(0L, log.logEndOffset)
-    log.appendAsFollower(MemoryRecords.withIdempotentRecords(0L, 
Compression.NONE, pid, epoch, baseSequence,
-      partitionLeaderEpoch, new SimpleRecord("a".getBytes), new 
SimpleRecord("b".getBytes)))
-    log.appendAsFollower(MemoryRecords.withIdempotentRecords(2L, 
Compression.NONE, pid, epoch, baseSequence,
-      partitionLeaderEpoch, new SimpleRecord("a".getBytes), new 
SimpleRecord("b".getBytes)))
+    log.appendAsFollower(
+      MemoryRecords.withIdempotentRecords(
+        0L,
+        Compression.NONE,
+        pid,
+        producerEpoch,
+        baseSequence,
+        partitionLeaderEpoch,
+        new SimpleRecord("a".getBytes),
+        new SimpleRecord("b".getBytes)
+      ),
+      partitionLeaderEpoch
+    )
+    log.appendAsFollower(
+      MemoryRecords.withIdempotentRecords(
+        2L,
+        Compression.NONE,
+        pid,
+        producerEpoch,
+        baseSequence,
+        partitionLeaderEpoch,
+        new SimpleRecord("a".getBytes),
+        new SimpleRecord("b".getBytes)
+      ),
+      partitionLeaderEpoch
+    )
 
     // Ensure that even the duplicate sequences are accepted on the follower.
     assertEquals(4L, log.logEndOffset)
@@ -1399,48 +1470,49 @@ class UnifiedLogTest {
 
     val pid1 = 1L
     val pid2 = 2L
-    val epoch: Short = 0
+    val producerEpoch: Short = 0
 
     val buffer = ByteBuffer.allocate(512)
 
     // pid1 seq = 0
     var builder = MemoryRecords.builder(buffer, 
RecordBatch.CURRENT_MAGIC_VALUE, Compression.NONE,
-      TimestampType.LOG_APPEND_TIME, 0L, mockTime.milliseconds(), pid1, epoch, 
0)
+      TimestampType.LOG_APPEND_TIME, 0L, mockTime.milliseconds(), pid1, 
producerEpoch, 0)
     builder.append(new SimpleRecord("key".getBytes, "value".getBytes))
     builder.close()
 
     // pid2 seq = 0
     builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, 
Compression.NONE,
-      TimestampType.LOG_APPEND_TIME, 1L, mockTime.milliseconds(), pid2, epoch, 
0)
+      TimestampType.LOG_APPEND_TIME, 1L, mockTime.milliseconds(), pid2, 
producerEpoch, 0)
     builder.append(new SimpleRecord("key".getBytes, "value".getBytes))
     builder.close()
 
     // pid1 seq = 1
     builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, 
Compression.NONE,
-      TimestampType.LOG_APPEND_TIME, 2L, mockTime.milliseconds(), pid1, epoch, 
1)
+      TimestampType.LOG_APPEND_TIME, 2L, mockTime.milliseconds(), pid1, 
producerEpoch, 1)
     builder.append(new SimpleRecord("key".getBytes, "value".getBytes))
     builder.close()
 
     // pid2 seq = 1
     builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, 
Compression.NONE,
-      TimestampType.LOG_APPEND_TIME, 3L, mockTime.milliseconds(), pid2, epoch, 
1)
+      TimestampType.LOG_APPEND_TIME, 3L, mockTime.milliseconds(), pid2, 
producerEpoch, 1)
     builder.append(new SimpleRecord("key".getBytes, "value".getBytes))
     builder.close()
 
     // // pid1 seq = 1 (duplicate)
     builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, 
Compression.NONE,
-      TimestampType.LOG_APPEND_TIME, 4L, mockTime.milliseconds(), pid1, epoch, 
1)
+      TimestampType.LOG_APPEND_TIME, 4L, mockTime.milliseconds(), pid1, 
producerEpoch, 1)
     builder.append(new SimpleRecord("key".getBytes, "value".getBytes))
     builder.close()
 
     buffer.flip()
 
+    val epoch = 0
     val records = MemoryRecords.readableRecords(buffer)
-    records.batches.forEach(_.setPartitionLeaderEpoch(0))
+    records.batches.forEach(_.setPartitionLeaderEpoch(epoch))
 
     // Ensure that batches with duplicates are accepted on the follower.
     assertEquals(0L, log.logEndOffset)
-    log.appendAsFollower(records)
+    log.appendAsFollower(records, epoch)
     assertEquals(5L, log.logEndOffset)
   }
 
@@ -1582,8 +1654,12 @@ class UnifiedLogTest {
     val records = messageIds.map(id => new SimpleRecord(id.toString.getBytes))
 
     // now test the case that we give the offsets and use non-sequential 
offsets
-    for (i <- records.indices)
-      log.appendAsFollower(MemoryRecords.withRecords(messageIds(i), 
Compression.NONE, 0, records(i)))
+    for (i <- records.indices) {
+      log.appendAsFollower(
+        MemoryRecords.withRecords(messageIds(i), Compression.NONE, 0, 
records(i)),
+        Int.MaxValue
+      )
+    }
     for (i <- 50 until messageIds.max) {
       val idx = messageIds.indexWhere(_ >= i)
       val read = LogTestUtils.readLog(log, i, 
100).records.records.iterator.next()
@@ -1630,8 +1706,12 @@ class UnifiedLogTest {
     val records = messageIds.map(id => new SimpleRecord(id.toString.getBytes))
 
     // now test the case that we give the offsets and use non-sequential 
offsets
-    for (i <- records.indices)
-      log.appendAsFollower(MemoryRecords.withRecords(messageIds(i), 
Compression.NONE, 0, records(i)))
+    for (i <- records.indices) {
+      log.appendAsFollower(
+        MemoryRecords.withRecords(messageIds(i), Compression.NONE, 0, 
records(i)),
+        Int.MaxValue
+      )
+    }
 
     for (i <- 50 until messageIds.max) {
       val idx = messageIds.indexWhere(_ >= i)
@@ -1655,8 +1735,12 @@ class UnifiedLogTest {
     val records = messageIds.map(id => new SimpleRecord(id.toString.getBytes))
 
     // now test the case that we give the offsets and use non-sequential 
offsets
-    for (i <- records.indices)
-      log.appendAsFollower(MemoryRecords.withRecords(messageIds(i), 
Compression.NONE, 0, records(i)))
+    for (i <- records.indices) {
+      log.appendAsFollower(
+        MemoryRecords.withRecords(messageIds(i), Compression.NONE, 0, 
records(i)),
+        Int.MaxValue
+      )
+    }
 
     for (i <- 50 until messageIds.max) {
       assertEquals(MemoryRecords.EMPTY, LogTestUtils.readLog(log, i, maxLength 
= 0, minOneMessage = false).records)
@@ -1904,9 +1988,94 @@ class UnifiedLogTest {
 
     val log = createLog(logDir, LogTestUtils.createLogConfig(maxMessageBytes = 
second.sizeInBytes - 1))
 
-    log.appendAsFollower(first)
+    log.appendAsFollower(first, Int.MaxValue)
     // the second record is larger than the limit but appendAsFollower does not
validate the size.
-    log.appendAsFollower(second)
+    log.appendAsFollower(second, Int.MaxValue)
+  }
+
+  @ParameterizedTest
+  @ArgumentsSource(classOf[InvalidMemoryRecordsProvider])
+  def testInvalidMemoryRecords(records: MemoryRecords, expectedException: 
Optional[Class[Exception]]): Unit = {
+    val logConfig = LogTestUtils.createLogConfig()
+    val log = createLog(logDir, logConfig)
+    val previousEndOffset = log.logEndOffsetMetadata.messageOffset
+
+    if (expectedException.isPresent()) {
+      assertThrows(
+        expectedException.get(),
+        () => log.appendAsFollower(records, Int.MaxValue)
+      )
+    } else {
+        log.appendAsFollower(records, Int.MaxValue)
+    }
+
+    assertEquals(previousEndOffset, log.logEndOffsetMetadata.messageOffset)
+  }
+
+  @Property(tries = 100, afterFailure = AfterFailureMode.SAMPLE_ONLY)
+  def testRandomRecords(
+    @ForAll(supplier = classOf[ArbitraryMemoryRecords]) records: MemoryRecords
+  ): Unit = {
+    val tempDir = TestUtils.tempDir()
+    val logDir = TestUtils.randomPartitionLogDir(tempDir)
+    try {
+      val logConfig = LogTestUtils.createLogConfig()
+      val log = createLog(logDir, logConfig)
+      val previousEndOffset = log.logEndOffsetMetadata.messageOffset
+
+      // Depending on the corruption, unified log sometimes throws and 
sometimes returns an
+      // empty set of batches
+      assertThrows(
+        classOf[CorruptRecordException],
+        () => {
+          val info = log.appendAsFollower(records, Int.MaxValue)
+          if (info.firstOffset == JUnifiedLog.UNKNOWN_OFFSET) {
+            throw new CorruptRecordException("Unknown offset is test")
+          }
+        }
+      )
+
+      assertEquals(previousEndOffset, log.logEndOffsetMetadata.messageOffset)
+    } finally {
+      Utils.delete(tempDir)
+    }
+  }
+
+  @Test
+  def testInvalidLeaderEpoch(): Unit = {
+    val logConfig = LogTestUtils.createLogConfig()
+    val log = createLog(logDir, logConfig)
+    val previousEndOffset = log.logEndOffsetMetadata.messageOffset
+    val epoch = log.latestEpoch.getOrElse(0) + 1
+    val numberOfRecords = 10
+
+    val batchWithValidEpoch = MemoryRecords.withRecords(
+      previousEndOffset,
+      Compression.NONE,
+      epoch,
+      (0 until numberOfRecords).map(number => new 
SimpleRecord(number.toString.getBytes)): _*
+    )
+
+    val batchWithInvalidEpoch = MemoryRecords.withRecords(
+      previousEndOffset + numberOfRecords,
+      Compression.NONE,
+      epoch + 1,
+      (0 until numberOfRecords).map(number => new 
SimpleRecord(number.toString.getBytes)): _*
+    )
+
+    val buffer = ByteBuffer.allocate(batchWithValidEpoch.sizeInBytes() + 
batchWithInvalidEpoch.sizeInBytes())
+    buffer.put(batchWithValidEpoch.buffer())
+    buffer.put(batchWithInvalidEpoch.buffer())
+    buffer.flip()
+
+    val records = MemoryRecords.readableRecords(buffer)
+
+    log.appendAsFollower(records, epoch)
+
+    // Check that only the first batch was appended
+    assertEquals(previousEndOffset + numberOfRecords, 
log.logEndOffsetMetadata.messageOffset)
+    // Check that the last fetched epoch matches the first batch
+    assertEquals(epoch, log.latestEpoch.get)
   }
 
   @Test
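
The new testInvalidLeaderEpoch above pins down the follower-append contract: batches are appended in order up to the first batch whose partition leader epoch exceeds the epoch passed to appendAsFollower; that batch and the ones after it are not appended (the mock fetcher later in this patch mirrors the same rule). A toy restatement of the cap with made-up epoch numbers, nothing here comes from the patch itself:

    // Which batch epochs survive an append with replica leader epoch 4.
    val replicaEpoch = 4
    val batchEpochs  = Seq(3, 4, 5, 4)
    val appended     = batchEpochs.takeWhile(_ <= replicaEpoch)
    assert(appended == Seq(3, 4)) // the epoch-5 batch and everything after it are skipped
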
@@ -1987,7 +2156,7 @@ class UnifiedLogTest {
     val messages = (0 until numMessages).map { i =>
       MemoryRecords.withRecords(100 + i, Compression.NONE, 0, new 
SimpleRecord(mockTime.milliseconds + i, i.toString.getBytes()))
     }
-    messages.foreach(log.appendAsFollower)
+    messages.foreach(message => log.appendAsFollower(message, Int.MaxValue))
     val timeIndexEntries = log.logSegments.asScala.foldLeft(0) { (entries, 
segment) => entries + segment.timeIndex.entries }
     assertEquals(numMessages - 1, timeIndexEntries, s"There should be 
${numMessages - 1} time index entries")
     assertEquals(mockTime.milliseconds + numMessages - 1, 
log.activeSegment.timeIndex.lastEntry.timestamp,
@@ -2131,7 +2300,7 @@ class UnifiedLogTest {
     // The cache can be updated directly after a leader change.
     // The new latest offset should reflect the updated epoch.
     log.assignEpochStartOffset(2, 2L)
-    
+
     assertEquals(new OffsetResultHolder(new 
TimestampAndOffset(ListOffsetsResponse.UNKNOWN_TIMESTAMP, 2L, Optional.of(2))),
       log.fetchOffsetByTimestamp(ListOffsetsRequest.LATEST_TIMESTAMP, 
Optional.of(remoteLogManager)))
   }
@@ -2399,20 +2568,22 @@ class UnifiedLogTest {
   def testAppendWithOutOfOrderOffsetsThrowsException(): Unit = {
     val log = createLog(logDir, new LogConfig(new Properties))
 
+    val epoch = 0
     val appendOffsets = Seq(0L, 1L, 3L, 2L, 4L)
     val buffer = ByteBuffer.allocate(512)
     for (offset <- appendOffsets) {
       val builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, 
Compression.NONE,
                                           TimestampType.LOG_APPEND_TIME, 
offset, mockTime.milliseconds(),
-                                          1L, 0, 0, false, 0)
+                                          1L, 0, 0, false, epoch)
       builder.append(new SimpleRecord("key".getBytes, "value".getBytes))
       builder.close()
     }
     buffer.flip()
     val memoryRecords = MemoryRecords.readableRecords(buffer)
 
-    assertThrows(classOf[OffsetsOutOfOrderException], () =>
-      log.appendAsFollower(memoryRecords)
+    assertThrows(
+      classOf[OffsetsOutOfOrderException],
+      () => log.appendAsFollower(memoryRecords, epoch)
     )
   }
 
@@ -2427,9 +2598,11 @@ class UnifiedLogTest {
     for (magic <- magicVals; compressionType <- compressionTypes) {
       val compression = Compression.of(compressionType).build()
       val invalidRecord = MemoryRecords.withRecords(magic, compression, new 
SimpleRecord(1.toString.getBytes))
-      assertThrows(classOf[UnexpectedAppendOffsetException],
-        () => log.appendAsFollower(invalidRecord),
-        () => s"Magic=$magic, compressionType=$compressionType")
+      assertThrows(
+        classOf[UnexpectedAppendOffsetException],
+        () => log.appendAsFollower(invalidRecord, Int.MaxValue),
+        () => s"Magic=$magic, compressionType=$compressionType"
+      )
     }
   }
 
@@ -2450,7 +2623,10 @@ class UnifiedLogTest {
                                     magicValue = magic, codec = 
Compression.of(compressionType).build(),
                                     baseOffset = firstOffset)
 
-      val exception = assertThrows(classOf[UnexpectedAppendOffsetException], 
() => log.appendAsFollower(records = batch))
+      val exception = assertThrows(
+        classOf[UnexpectedAppendOffsetException],
+        () => log.appendAsFollower(records = batch, Int.MaxValue)
+      )
       assertEquals(firstOffset, exception.firstOffset, s"Magic=$magic, 
compressionType=$compressionType, UnexpectedAppendOffsetException#firstOffset")
       assertEquals(firstOffset + 2, exception.lastOffset, s"Magic=$magic, 
compressionType=$compressionType, UnexpectedAppendOffsetException#lastOffset")
     }
@@ -2549,9 +2725,16 @@ class UnifiedLogTest {
     log.appendAsLeader(TestUtils.records(List(new 
SimpleRecord("foo".getBytes()))), leaderEpoch = 5)
     assertEquals(OptionalInt.of(5), log.leaderEpochCache.latestEpoch)
 
-    log.appendAsFollower(TestUtils.records(List(new 
SimpleRecord("foo".getBytes())),
-      baseOffset = 1L,
-      magicValue = RecordVersion.V1.value))
+    log.appendAsFollower(
+      TestUtils.records(
+        List(
+          new SimpleRecord("foo".getBytes())
+        ),
+        baseOffset = 1L,
+        magicValue = RecordVersion.V1.value
+      ),
+      5
+    )
     assertEquals(OptionalInt.empty, log.leaderEpochCache.latestEpoch)
   }
 
@@ -2907,7 +3090,7 @@ class UnifiedLogTest {
 
     //When appending as follower (assignOffsets = false)
     for (i <- records.indices)
-      log.appendAsFollower(recordsForEpoch(i))
+      log.appendAsFollower(recordsForEpoch(i), i)
 
     assertEquals(Some(42), log.latestEpoch)
   }
@@ -2975,7 +3158,7 @@ class UnifiedLogTest {
 
     def append(epoch: Int, startOffset: Long, count: Int): Unit = {
       for (i <- 0 until count)
-        log.appendAsFollower(createRecords(startOffset + i, epoch))
+        log.appendAsFollower(createRecords(startOffset + i, epoch), epoch)
     }
 
     //Given 2 segments, 10 messages per segment
@@ -3209,7 +3392,7 @@ class UnifiedLogTest {
 
     buffer.flip()
 
-    appendAsFollower(log, MemoryRecords.readableRecords(buffer))
+    appendAsFollower(log, MemoryRecords.readableRecords(buffer), epoch)
 
     val abortedTransactions = LogTestUtils.allAbortedTransactions(log)
     val expectedTransactions = List(
@@ -3293,7 +3476,7 @@ class UnifiedLogTest {
     appendEndTxnMarkerToBuffer(buffer, pid, epoch, 10L, 
ControlRecordType.COMMIT, leaderEpoch = 1)
 
     buffer.flip()
-    log.appendAsFollower(MemoryRecords.readableRecords(buffer))
+    log.appendAsFollower(MemoryRecords.readableRecords(buffer), epoch)
 
     LogTestUtils.appendEndTxnMarkerAsLeader(log, pid, epoch, 
ControlRecordType.ABORT, mockTime.milliseconds(), coordinatorEpoch = 2, 
leaderEpoch = 1)
     LogTestUtils.appendEndTxnMarkerAsLeader(log, pid, epoch, 
ControlRecordType.ABORT, mockTime.milliseconds(), coordinatorEpoch = 2, 
leaderEpoch = 1)
@@ -3414,10 +3597,16 @@ class UnifiedLogTest {
     val log = createLog(logDir, logConfig)
 
     // append a few records
-    appendAsFollower(log, MemoryRecords.withRecords(Compression.NONE,
-      new SimpleRecord("a".getBytes),
-      new SimpleRecord("b".getBytes),
-      new SimpleRecord("c".getBytes)), 5)
+    appendAsFollower(
+      log,
+      MemoryRecords.withRecords(
+        Compression.NONE,
+        new SimpleRecord("a".getBytes),
+        new SimpleRecord("b".getBytes),
+        new SimpleRecord("c".getBytes)
+      ),
+      5
+    )
 
 
     log.updateHighWatermark(3L)
@@ -4484,9 +4673,9 @@ class UnifiedLogTest {
     builder.close()
   }
 
-  private def appendAsFollower(log: UnifiedLog, records: MemoryRecords, 
leaderEpoch: Int = 0): Unit = {
+  private def appendAsFollower(log: UnifiedLog, records: MemoryRecords, 
leaderEpoch: Int): Unit = {
     records.batches.forEach(_.setPartitionLeaderEpoch(leaderEpoch))
-    log.appendAsFollower(records)
+    log.appendAsFollower(records, leaderEpoch)
   }
 
   private def createLog(dir: File,
diff --git 
a/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala 
b/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala
index 7528eefc420..d1a05e7d915 100644
--- a/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala
@@ -328,9 +328,12 @@ class AbstractFetcherManagerTest {
       fetchBackOffMs = 0,
       brokerTopicStats = new BrokerTopicStats) {
 
-    override protected def processPartitionData(topicPartition: 
TopicPartition, fetchOffset: Long, partitionData: FetchData): 
Option[LogAppendInfo] = {
-      None
-    }
+    override protected def processPartitionData(
+      topicPartition: TopicPartition,
+      fetchOffset: Long,
+      partitionLeaderEpoch: Int,
+      partitionData: FetchData
+    ): Option[LogAppendInfo] = None
 
     override protected def truncate(topicPartition: TopicPartition, 
truncationState: OffsetTruncationState): Unit = {}
 
diff --git 
a/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala 
b/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala
index 0e38e9dfcb0..19856e1cdd2 100644
--- a/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala
+++ b/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala
@@ -630,6 +630,7 @@ class AbstractFetcherThreadTest {
 
   @Test
   def testFollowerFetchOutOfRangeLow(): Unit = {
+    val leaderEpoch = 4
     val partition = new TopicPartition("topic", 0)
     val mockLeaderEndpoint = new MockLeaderEndPoint(version = version)
     val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint)
@@ -639,14 +640,19 @@ class AbstractFetcherThreadTest {
     val replicaLog = Seq(
       mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes)))
 
-    val replicaState = PartitionState(replicaLog, leaderEpoch = 0, 
highWatermark = 0L)
+    val replicaState = PartitionState(replicaLog, leaderEpoch = leaderEpoch, 
highWatermark = 0L)
     fetcher.setReplicaState(partition, replicaState)
-    fetcher.addPartitions(Map(partition -> 
initialFetchState(topicIds.get(partition.topic), 3L, leaderEpoch = 0)))
+    fetcher.addPartitions(
+      Map(
+        partition -> initialFetchState(topicIds.get(partition.topic), 3L, 
leaderEpoch = leaderEpoch)
+      )
+    )
 
     val leaderLog = Seq(
-      mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes)))
+      mkBatch(baseOffset = 2, leaderEpoch = leaderEpoch, new 
SimpleRecord("c".getBytes))
+    )
 
-    val leaderState = PartitionState(leaderLog, leaderEpoch = 0, highWatermark 
= 2L)
+    val leaderState = PartitionState(leaderLog, leaderEpoch = leaderEpoch, 
highWatermark = 2L)
     fetcher.mockLeader.setLeaderState(partition, leaderState)
     
fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState)
 
@@ -671,6 +677,7 @@ class AbstractFetcherThreadTest {
 
   @Test
   def testRetryAfterUnknownLeaderEpochInLatestOffsetFetch(): Unit = {
+    val leaderEpoch = 4
     val partition = new TopicPartition("topic", 0)
     val mockLeaderEndPoint = new MockLeaderEndPoint(version = version) {
       val tries = new AtomicInteger(0)
@@ -685,16 +692,18 @@ class AbstractFetcherThreadTest {
 
     // The follower begins from an offset which is behind the leader's log 
start offset
     val replicaLog = Seq(
-      mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes)))
+      mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes))
+    )
 
-    val replicaState = PartitionState(replicaLog, leaderEpoch = 0, 
highWatermark = 0L)
+    val replicaState = PartitionState(replicaLog, leaderEpoch = leaderEpoch, 
highWatermark = 0L)
     fetcher.setReplicaState(partition, replicaState)
-    fetcher.addPartitions(Map(partition -> 
initialFetchState(topicIds.get(partition.topic), 3L, leaderEpoch = 0)))
+    fetcher.addPartitions(Map(partition -> 
initialFetchState(topicIds.get(partition.topic), 3L, leaderEpoch = 
leaderEpoch)))
 
     val leaderLog = Seq(
-      mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes)))
+      mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes))
+    )
 
-    val leaderState = PartitionState(leaderLog, leaderEpoch = 0, highWatermark 
= 2L)
+    val leaderState = PartitionState(leaderLog, leaderEpoch = leaderEpoch, 
highWatermark = 2L)
     fetcher.mockLeader.setLeaderState(partition, leaderState)
     
fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState)
 
@@ -712,6 +721,46 @@ class AbstractFetcherThreadTest {
     assertEquals(leaderState.highWatermark, replicaState.highWatermark)
   }
 
+  @Test
+  def testReplicateBatchesUpToLeaderEpoch(): Unit = {
+    val leaderEpoch = 4
+    val partition = new TopicPartition("topic", 0)
+    val mockLeaderEndpoint = new MockLeaderEndPoint(version = version)
+    val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint)
+    val fetcher = new MockFetcherThread(mockLeaderEndpoint, 
mockTierStateMachine, failedPartitions = failedPartitions)
+
+    val replicaState = PartitionState(Seq(), leaderEpoch = leaderEpoch, 
highWatermark = 0L)
+    fetcher.setReplicaState(partition, replicaState)
+    fetcher.addPartitions(
+      Map(
+        partition -> initialFetchState(topicIds.get(partition.topic), 3L, 
leaderEpoch = leaderEpoch)
+      )
+    )
+
+    val leaderLog = Seq(
+      mkBatch(baseOffset = 0, leaderEpoch = leaderEpoch - 1, new 
SimpleRecord("c".getBytes)),
+      mkBatch(baseOffset = 1, leaderEpoch = leaderEpoch, new 
SimpleRecord("d".getBytes)),
+      mkBatch(baseOffset = 2, leaderEpoch = leaderEpoch + 1, new 
SimpleRecord("e".getBytes))
+    )
+
+    val leaderState = PartitionState(leaderLog, leaderEpoch = leaderEpoch, 
highWatermark = 0L)
+    fetcher.mockLeader.setLeaderState(partition, leaderState)
+    
fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState)
+
+    assertEquals(Option(Fetching), fetcher.fetchState(partition).map(_.state))
+    assertEquals(0, replicaState.logStartOffset)
+    assertEquals(List(), replicaState.log.toList)
+
+    TestUtils.waitUntilTrue(() => {
+      fetcher.doWork()
+      fetcher.replicaPartitionState(partition).log == 
fetcher.mockLeader.leaderPartitionState(partition).log.dropRight(1)
+    }, "Failed to reconcile leader and follower logs up to the leader epoch")
+
+    assertEquals(leaderState.logStartOffset, replicaState.logStartOffset)
+    assertEquals(leaderState.logEndOffset - 1, replicaState.logEndOffset)
+    assertEquals(leaderState.highWatermark, replicaState.highWatermark)
+  }
+
   @Test
   def testCorruptMessage(): Unit = {
     val partition = new TopicPartition("topic", 0)
@@ -897,11 +946,16 @@ class AbstractFetcherThreadTest {
     val mockLeaderEndpoint = new MockLeaderEndPoint(version = version)
     val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint)
     val fetcherForAppend = new MockFetcherThread(mockLeaderEndpoint, 
mockTierStateMachine, failedPartitions = failedPartitions) {
-      override def processPartitionData(topicPartition: TopicPartition, 
fetchOffset: Long, partitionData: FetchData): Option[LogAppendInfo] = {
+      override def processPartitionData(
+        topicPartition: TopicPartition,
+        fetchOffset: Long,
+        partitionLeaderEpoch: Int,
+        partitionData: FetchData
+      ): Option[LogAppendInfo] = {
         if (topicPartition == partition1) {
           throw new KafkaException()
         } else {
-          super.processPartitionData(topicPartition, fetchOffset, 
partitionData)
+          super.processPartitionData(topicPartition, fetchOffset, 
partitionLeaderEpoch, partitionData)
         }
       }
     }
@@ -1003,9 +1057,14 @@ class AbstractFetcherThreadTest {
     val mockLeaderEndpoint = new MockLeaderEndPoint(version = version)
     val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint)
     val fetcher = new MockFetcherThread(mockLeaderEndpoint, 
mockTierStateMachine) {
-      override def processPartitionData(topicPartition: TopicPartition, 
fetchOffset: Long, partitionData: FetchData): Option[LogAppendInfo] = {
+      override def processPartitionData(
+        topicPartition: TopicPartition,
+        fetchOffset: Long,
+        partitionLeaderEpoch: Int,
+        partitionData: FetchData
+      ): Option[LogAppendInfo] = {
         processPartitionDataCalls += 1
-        super.processPartitionData(topicPartition, fetchOffset, partitionData)
+        super.processPartitionData(topicPartition, fetchOffset, 
partitionLeaderEpoch, partitionData)
       }
 
       override def truncate(topicPartition: TopicPartition, truncationState: 
OffsetTruncationState): Unit = {
diff --git a/core/src/test/scala/unit/kafka/server/MockFetcherThread.scala 
b/core/src/test/scala/unit/kafka/server/MockFetcherThread.scala
index 5d50de04095..ff1e9196568 100644
--- a/core/src/test/scala/unit/kafka/server/MockFetcherThread.scala
+++ b/core/src/test/scala/unit/kafka/server/MockFetcherThread.scala
@@ -66,9 +66,12 @@ class MockFetcherThread(val mockLeader: MockLeaderEndPoint,
     partitions
   }
 
-  override def processPartitionData(topicPartition: TopicPartition,
-                                    fetchOffset: Long,
-                                    partitionData: FetchData): 
Option[LogAppendInfo] = {
+  override def processPartitionData(
+    topicPartition: TopicPartition,
+    fetchOffset: Long,
+    leaderEpochForReplica: Int,
+    partitionData: FetchData
+  ): Option[LogAppendInfo] = {
     val state = replicaPartitionState(topicPartition)
 
     if (leader.isTruncationOnFetchSupported && 
FetchResponse.isDivergingEpoch(partitionData)) {
@@ -87,17 +90,24 @@ class MockFetcherThread(val mockLeader: MockLeaderEndPoint,
     var shallowOffsetOfMaxTimestamp = -1L
     var lastOffset = state.logEndOffset
     var lastEpoch: OptionalInt = OptionalInt.empty()
+    var skipRemainingBatches = false
 
     for (batch <- batches) {
       batch.ensureValid()
-      if (batch.maxTimestamp > maxTimestamp) {
-        maxTimestamp = batch.maxTimestamp
-        shallowOffsetOfMaxTimestamp = batch.baseOffset
+
+      skipRemainingBatches = skipRemainingBatches || 
hasHigherPartitionLeaderEpoch(batch, leaderEpochForReplica);
+      if (skipRemainingBatches) {
+        info(s"Skipping batch $batch because leader epoch is 
$leaderEpochForReplica")
+      } else {
+        if (batch.maxTimestamp > maxTimestamp) {
+          maxTimestamp = batch.maxTimestamp
+          shallowOffsetOfMaxTimestamp = batch.baseOffset
+        }
+        state.log.append(batch)
+        state.logEndOffset = batch.nextOffset
+        lastOffset = batch.lastOffset
+        lastEpoch = OptionalInt.of(batch.partitionLeaderEpoch)
       }
-      state.log.append(batch)
-      state.logEndOffset = batch.nextOffset
-      lastOffset = batch.lastOffset
-      lastEpoch = OptionalInt.of(batch.partitionLeaderEpoch)
     }
 
     state.logStartOffset = partitionData.logStartOffset
@@ -115,6 +125,11 @@ class MockFetcherThread(val mockLeader: MockLeaderEndPoint,
       batches.headOption.map(_.lastOffset).getOrElse(-1)))
   }
 
+  private def hasHigherPartitionLeaderEpoch(batch: RecordBatch, leaderEpoch: 
Int): Boolean = {
+    batch.partitionLeaderEpoch() != RecordBatch.NO_PARTITION_LEADER_EPOCH &&
+    batch.partitionLeaderEpoch() > leaderEpoch
+  }
+
   override def truncate(topicPartition: TopicPartition, truncationState: 
OffsetTruncationState): Unit = {
     val state = replicaPartitionState(topicPartition)
     state.log = state.log.takeWhile { batch =>
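
One detail of the mock fetcher's filter above is easy to miss: a batch whose partitionLeaderEpoch is RecordBatch.NO_PARTITION_LEADER_EPOCH (-1) never counts as "higher", so batches written without an epoch are still replicated. A condensed, self-contained restatement of hasHigherPartitionLeaderEpoch, using a throwaway case class in place of a real RecordBatch:

    final case class Batch(partitionLeaderEpoch: Int)
    val NoEpoch = -1 // stands in for RecordBatch.NO_PARTITION_LEADER_EPOCH

    def hasHigherEpoch(batch: Batch, replicaEpoch: Int): Boolean =
      batch.partitionLeaderEpoch != NoEpoch && batch.partitionLeaderEpoch > replicaEpoch

    assert(!hasHigherEpoch(Batch(NoEpoch), replicaEpoch = 4)) // epoch-less batches are kept
    assert(!hasHigherEpoch(Batch(4), replicaEpoch = 4))       // an equal epoch is kept
    assert(hasHigherEpoch(Batch(5), replicaEpoch = 4))        // a higher epoch starts the skip
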
diff --git 
a/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala 
b/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala
index 6526d6628c3..b0ee5a2d148 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala
@@ -281,9 +281,22 @@ class ReplicaFetcherThreadTest {
     val fetchSessionHandler = new FetchSessionHandler(logContext, 
brokerEndPoint.id)
     val leader = new RemoteLeaderEndPoint(logContext.logPrefix, mockNetwork, 
fetchSessionHandler, config,
       replicaManager, quota, () => MetadataVersion.MINIMUM_VERSION, () => 1)
-    val thread = new ReplicaFetcherThread("bob", leader, config, 
failedPartitions,
-      replicaManager, quota, logContext.logPrefix, () => 
MetadataVersion.MINIMUM_VERSION) {
-      override def processPartitionData(topicPartition: TopicPartition, 
fetchOffset: Long, partitionData: FetchData): Option[LogAppendInfo] = None
+    val thread = new ReplicaFetcherThread(
+      "bob",
+      leader,
+      config,
+      failedPartitions,
+      replicaManager,
+      quota,
+      logContext.logPrefix,
+      () => MetadataVersion.MINIMUM_VERSION
+    ) {
+      override def processPartitionData(
+        topicPartition: TopicPartition,
+        fetchOffset: Long,
+        partitionLeaderEpoch: Int,
+        partitionData: FetchData
+      ): Option[LogAppendInfo] = None
     }
     thread.addPartitions(Map(t1p0 -> initialFetchState(Some(topicId1), 
initialLEO), t1p1 -> initialFetchState(Some(topicId1), initialLEO)))
     val partitions = Set(t1p0, t1p1)
@@ -379,7 +392,7 @@ class ReplicaFetcherThreadTest {
     when(replicaManager.getPartitionOrException(t1p0)).thenReturn(partition)
 
     when(partition.localLogOrException).thenReturn(log)
-    when(partition.appendRecordsToFollowerOrFutureReplica(any(), 
any())).thenReturn(None)
+    when(partition.appendRecordsToFollowerOrFutureReplica(any(), any(), 
any())).thenReturn(None)
 
     val logContext = new LogContext(s"[ReplicaFetcher 
replicaId=${config.brokerId}, leaderId=${brokerEndPoint.id}, fetcherId=0] ")
 
@@ -460,7 +473,7 @@ class ReplicaFetcherThreadTest {
     
when(replicaManager.brokerTopicStats).thenReturn(mock(classOf[BrokerTopicStats]))
 
     when(partition.localLogOrException).thenReturn(log)
-    when(partition.appendRecordsToFollowerOrFutureReplica(any(), 
any())).thenReturn(Some(new LogAppendInfo(
+    when(partition.appendRecordsToFollowerOrFutureReplica(any(), any(), 
any())).thenReturn(Some(new LogAppendInfo(
       -1,
       0,
       OptionalInt.empty,
@@ -679,7 +692,7 @@ class ReplicaFetcherThreadTest {
 
     val partition: Partition = mock(classOf[Partition])
     when(partition.localLogOrException).thenReturn(log)
-    when(partition.appendRecordsToFollowerOrFutureReplica(any[MemoryRecords], 
any[Boolean])).thenReturn(appendInfo)
+    when(partition.appendRecordsToFollowerOrFutureReplica(any[MemoryRecords], 
any[Boolean], any[Int])).thenReturn(appendInfo)
 
     // Capture the argument at the time of invocation.
     val completeDelayedFetchRequestsArgument = 
mutable.Buffer.empty[TopicPartition]
@@ -710,8 +723,8 @@ class ReplicaFetcherThreadTest {
       .setRecords(records)
       .setHighWatermark(highWatermarkReceivedFromLeader)
 
-    thread.processPartitionData(tp0, 0, partitionData.setPartitionIndex(0))
-    thread.processPartitionData(tp1, 0, partitionData.setPartitionIndex(1))
+    thread.processPartitionData(tp0, 0, Int.MaxValue, 
partitionData.setPartitionIndex(0))
+    thread.processPartitionData(tp1, 0, Int.MaxValue, 
partitionData.setPartitionIndex(1))
     verify(replicaManager, 
times(0)).completeDelayedFetchRequests(any[Seq[TopicPartition]])
 
     thread.doWork()
@@ -761,7 +774,7 @@ class ReplicaFetcherThreadTest {
     when(partition.localLogOrException).thenReturn(log)
     when(partition.isReassigning).thenReturn(isReassigning)
     when(partition.isAddingLocalReplica).thenReturn(isReassigning)
-    when(partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = 
false)).thenReturn(None)
+    when(partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = 
false, Int.MaxValue)).thenReturn(None)
 
     val replicaManager: ReplicaManager = mock(classOf[ReplicaManager])
     
when(replicaManager.getPartitionOrException(any[TopicPartition])).thenReturn(partition)
@@ -785,7 +798,7 @@ class ReplicaFetcherThreadTest {
       .setLastStableOffset(0)
       .setLogStartOffset(0)
       .setRecords(records)
-    thread.processPartitionData(t1p0, 0, partitionData)
+    thread.processPartitionData(t1p0, 0, Int.MaxValue, partitionData)
 
     if (isReassigning)
       assertEquals(records.sizeInBytes(), 
brokerTopicStats.allTopicsStats.reassignmentBytesInPerSec.get.count())
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala 
b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
index 6a27776babc..59d9b4b1a63 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
@@ -5253,9 +5253,12 @@ class ReplicaManagerTest {
         replicaManager.getPartition(topicPartition) match {
           case HostedPartition.Online(partition) =>
             partition.appendRecordsToFollowerOrFutureReplica(
-              records = MemoryRecords.withRecords(Compression.NONE, 0,
-                new SimpleRecord("first message".getBytes)),
-              isFuture = false
+              records = MemoryRecords.withRecords(
+                Compression.NONE, 0,
+                new SimpleRecord("first message".getBytes)
+              ),
+              isFuture = false,
+              partitionLeaderEpoch = 0
             )
 
           case _ =>
diff --git 
a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetcher/ReplicaFetcherThreadBenchmark.java
 
b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetcher/ReplicaFetcherThreadBenchmark.java
index d6bec0c8016..d9091bd5b57 100644
--- 
a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetcher/ReplicaFetcherThreadBenchmark.java
+++ 
b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetcher/ReplicaFetcherThreadBenchmark.java
@@ -335,8 +335,12 @@ public class ReplicaFetcherThreadBenchmark {
         }
 
         @Override
-        public Option<LogAppendInfo> processPartitionData(TopicPartition 
topicPartition, long fetchOffset,
-                                                          
FetchResponseData.PartitionData partitionData) {
+        public Option<LogAppendInfo> processPartitionData(
+            TopicPartition topicPartition,
+            long fetchOffset,
+            int partitionLeaderEpoch,
+            FetchResponseData.PartitionData partitionData
+        ) {
             return Option.empty();
         }
     }
diff --git 
a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/partition/PartitionMakeFollowerBenchmark.java
 
b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/partition/PartitionMakeFollowerBenchmark.java
index feb8c985904..a345f3907b8 100644
--- 
a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/partition/PartitionMakeFollowerBenchmark.java
+++ 
b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/partition/PartitionMakeFollowerBenchmark.java
@@ -134,7 +134,7 @@ public class PartitionMakeFollowerBenchmark {
             int initialOffSet = 0;
             while (true) {
                 MemoryRecords memoryRecords =  
MemoryRecords.withRecords(initialOffSet, Compression.NONE, 0, simpleRecords);
-                
partition.appendRecordsToFollowerOrFutureReplica(memoryRecords, false);
+                
partition.appendRecordsToFollowerOrFutureReplica(memoryRecords, false, 
Integer.MAX_VALUE);
                 initialOffSet = initialOffSet + 2;
             }
         });
diff --git a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java 
b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
index 538eec64e91..34b5770cf70 100644
--- a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
+++ b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
@@ -16,6 +16,7 @@
  */
 package org.apache.kafka.raft;
 
+import org.apache.kafka.common.InvalidRecordException;
 import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.Node;
 import org.apache.kafka.common.TopicPartition;
@@ -23,6 +24,7 @@ import org.apache.kafka.common.Uuid;
 import org.apache.kafka.common.compress.Compression;
 import org.apache.kafka.common.config.ConfigException;
 import org.apache.kafka.common.errors.ClusterAuthorizationException;
+import org.apache.kafka.common.errors.CorruptRecordException;
 import org.apache.kafka.common.errors.NotLeaderOrFollowerException;
 import org.apache.kafka.common.feature.SupportedVersionRange;
 import org.apache.kafka.common.memory.MemoryPool;
@@ -50,6 +52,7 @@ import org.apache.kafka.common.network.ListenerName;
 import org.apache.kafka.common.protocol.ApiKeys;
 import org.apache.kafka.common.protocol.ApiMessage;
 import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.record.DefaultRecordBatch;
 import org.apache.kafka.common.record.MemoryRecords;
 import org.apache.kafka.common.record.Records;
 import org.apache.kafka.common.record.UnalignedMemoryRecords;
@@ -93,8 +96,10 @@ import org.apache.kafka.snapshot.SnapshotWriter;
 import org.slf4j.Logger;
 
 import java.net.InetSocketAddress;
+import java.nio.ByteBuffer;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.HexFormat;
 import java.util.IdentityHashMap;
 import java.util.Iterator;
 import java.util.List;
@@ -1785,10 +1790,7 @@ public final class KafkaRaftClient<T> implements RaftClient<T> {
                     }
                 }
             } else {
-                Records records = FetchResponse.recordsOrFail(partitionResponse);
-                if (records.sizeInBytes() > 0) {
-                    appendAsFollower(records);
-                }
+                appendAsFollower(FetchResponse.recordsOrFail(partitionResponse));
 
                 OptionalLong highWatermark = partitionResponse.highWatermark() < 0 ?
                     OptionalLong.empty() : OptionalLong.of(partitionResponse.highWatermark());
@@ -1802,10 +1804,31 @@ public final class KafkaRaftClient<T> implements RaftClient<T> {
         }
     }
 
-    private void appendAsFollower(
-        Records records
-    ) {
-        LogAppendInfo info = log.appendAsFollower(records);
+    private static String convertToHexadecimal(Records records) {
+        ByteBuffer buffer = ((MemoryRecords) records).buffer();
+        byte[] bytes = new byte[Math.min(buffer.remaining(), DefaultRecordBatch.RECORD_BATCH_OVERHEAD)];
+        buffer.get(bytes);
+
+        return HexFormat.of().formatHex(bytes);
+    }
+
+    private void appendAsFollower(Records records) {
+        if (records.sizeInBytes() == 0) {
+            // Nothing to do if there are no bytes in the response
+            return;
+        }
+
+        try {
+            var info = log.appendAsFollower(records, quorum.epoch());
+            var info = log.appendAsFollower(records, quorum.epoch());
+            kafkaRaftMetrics.updateFetchedRecords(info.lastOffset - info.firstOffset + 1);
+        } catch (CorruptRecordException | InvalidRecordException e) {
+            logger.info(
+                "Failed to append the records with the batch header '{}' to the log",
+                convertToHexadecimal(records),
+                e
+            );
+        }
+
         if (quorum.isVoter() || followersAlwaysFlush) {
             // the leader only requires that voters have flushed their log before sending a Fetch
             // request. Because of reconfiguration some observers (that are getting added to the
@@ -1817,14 +1840,11 @@ public final class KafkaRaftClient<T> implements RaftClient<T> {
         partitionState.updateState();
 
         OffsetAndEpoch endOffset = endOffset();
-        kafkaRaftMetrics.updateFetchedRecords(info.lastOffset - info.firstOffset + 1);
         kafkaRaftMetrics.updateLogEnd(endOffset);
         logger.trace("Follower end offset updated to {} after append", endOffset);
     }
 
-    private LogAppendInfo appendAsLeader(
-        Records records
-    ) {
+    private LogAppendInfo appendAsLeader(Records records) {
         LogAppendInfo info = log.appendAsLeader(records, quorum.epoch());
 
         partitionState.updateState();
@@ -3475,6 +3495,10 @@ public final class KafkaRaftClient<T> implements RaftClient<T> {
             () -> new NotLeaderException("Append failed because the replica is not the current leader")
         );
 
+        if (records.isEmpty()) {
+            throw new IllegalArgumentException("Append failed because there are no records");
+        }
+
         BatchAccumulator<T> accumulator = leaderState.accumulator();
         boolean isFirstAppend = accumulator.isEmpty();
         final long offset = accumulator.append(epoch, records, true);
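
    A note on the new error path above: the hex dump that appendAsFollower logs is just the first RECORD_BATCH_OVERHEAD bytes of the fetched buffer. The following standalone sketch is my own illustration, not part of the patch; the class name and the use of a duplicated buffer are invented for the example. It shows roughly what that header dump looks like for a well-formed batch:

    import org.apache.kafka.common.compress.Compression;
    import org.apache.kafka.common.record.DefaultRecordBatch;
    import org.apache.kafka.common.record.MemoryRecords;
    import org.apache.kafka.common.record.SimpleRecord;

    import java.nio.ByteBuffer;
    import java.util.HexFormat;

    public class BatchHeaderHexDumpSketch {
        public static void main(String[] args) {
            // Build a single small batch with partition leader epoch 0.
            MemoryRecords records = MemoryRecords.withRecords(
                0L, Compression.NONE, 0, new SimpleRecord("value".getBytes())
            );

            // Copy at most one record-batch header worth of bytes; duplicating the
            // buffer keeps the original read position untouched.
            ByteBuffer buffer = records.buffer().duplicate();
            byte[] header = new byte[Math.min(buffer.remaining(), DefaultRecordBatch.RECORD_BATCH_OVERHEAD)];
            buffer.get(header);

            // Hex-encode the header bytes, which is enough to read off the base
            // offset, batch length, partition leader epoch and CRC by hand.
            System.out.println(HexFormat.of().formatHex(header));
        }
    }
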
diff --git a/raft/src/main/java/org/apache/kafka/raft/ReplicatedLog.java b/raft/src/main/java/org/apache/kafka/raft/ReplicatedLog.java
index a22f7fd73cd..8f5ba31a45d 100644
--- a/raft/src/main/java/org/apache/kafka/raft/ReplicatedLog.java
+++ b/raft/src/main/java/org/apache/kafka/raft/ReplicatedLog.java
@@ -31,6 +31,8 @@ public interface ReplicatedLog extends AutoCloseable {
      * be written atomically in a single batch or the call will fail and raise an
      * exception.
      *
+     * @param records record batches to append
+     * @param epoch the epoch of the replica
      * @return the metadata information of the appended batch
      * @throws IllegalArgumentException if the record set is empty
      * @throws RuntimeException if the batch base offset doesn't match the log end offset
@@ -42,11 +44,16 @@ public interface ReplicatedLog extends AutoCloseable {
      * difference from appendAsLeader is that we do not need to assign the epoch
      * or do additional validation.
      *
+     * The log will append record batches up to and including batches that have a partition
+     * leader epoch less than or equal to the passed epoch.
+     *
+     * @param records record batches to append
+     * @param epoch the epoch of the replica
      * @return the metadata information of the appended batch
      * @throws IllegalArgumentException if the record set is empty
      * @throws RuntimeException if the batch base offset doesn't match the log end offset
      */
-    LogAppendInfo appendAsFollower(Records records);
+    LogAppendInfo appendAsFollower(Records records, int epoch);
 
     /**
      * Read a set of records within a range of offsets.
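
    To make the new appendAsFollower contract above concrete, here is a rough standalone sketch of keeping only the leading batches whose partition leader epoch is at most the replica's epoch. This is not the UnifiedLog or KafkaMetadataLog implementation; the class and method names are invented for illustration and only public MemoryRecords APIs are used:

    import org.apache.kafka.common.compress.Compression;
    import org.apache.kafka.common.record.MemoryRecords;
    import org.apache.kafka.common.record.MutableRecordBatch;
    import org.apache.kafka.common.record.SimpleRecord;

    import java.nio.ByteBuffer;

    public class FollowerEpochFilterSketch {
        // Keep only the leading batches whose partition leader epoch is <= the replica's epoch.
        static MemoryRecords filterByEpoch(MemoryRecords records, int replicaEpoch) {
            int validBytes = 0;
            for (MutableRecordBatch batch : records.batches()) {
                if (batch.partitionLeaderEpoch() > replicaEpoch) {
                    break; // stop at the first batch that comes from a newer epoch
                }
                validBytes += batch.sizeInBytes();
            }

            ByteBuffer valid = records.buffer().duplicate();
            valid.limit(valid.position() + validBytes);
            return MemoryRecords.readableRecords(valid.slice());
        }

        public static void main(String[] args) {
            MemoryRecords epoch1 = MemoryRecords.withRecords(0L, Compression.NONE, 1, new SimpleRecord("a".getBytes()));
            MemoryRecords epoch2 = MemoryRecords.withRecords(1L, Compression.NONE, 2, new SimpleRecord("b".getBytes()));

            ByteBuffer buffer = ByteBuffer.allocate(epoch1.sizeInBytes() + epoch2.sizeInBytes());
            buffer.put(epoch1.buffer());
            buffer.put(epoch2.buffer());
            buffer.flip();

            // With a replica epoch of 1 only the first batch survives the filter.
            MemoryRecords filtered = filterByEpoch(MemoryRecords.readableRecords(buffer), 1);
            System.out.println(filtered.sizeInBytes() == epoch1.sizeInBytes()); // prints true
        }
    }
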
diff --git a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientFetchTest.java b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientFetchTest.java
new file mode 100644
index 00000000000..ade509d8051
--- /dev/null
+++ b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientFetchTest.java
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.raft;
+
+import org.apache.kafka.common.compress.Compression;
+import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.record.ArbitraryMemoryRecords;
+import org.apache.kafka.common.record.InvalidMemoryRecordsProvider;
+import org.apache.kafka.common.record.MemoryRecords;
+import org.apache.kafka.common.record.SimpleRecord;
+import org.apache.kafka.server.common.KRaftVersion;
+
+import net.jqwik.api.AfterFailureMode;
+import net.jqwik.api.ForAll;
+import net.jqwik.api.Property;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ArgumentsSource;
+
+import java.nio.ByteBuffer;
+import java.util.Optional;
+import java.util.stream.IntStream;
+import java.util.stream.Stream;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public final class KafkaRaftClientFetchTest {
+    @Property(tries = 100, afterFailure = AfterFailureMode.SAMPLE_ONLY)
+    void testRandomRecords(
+        @ForAll(supplier = ArbitraryMemoryRecords.class) MemoryRecords memoryRecords
+    ) throws Exception {
+        testFetchResponseWithInvalidRecord(memoryRecords, Integer.MAX_VALUE);
+    }
+
+    @ParameterizedTest
+    @ArgumentsSource(InvalidMemoryRecordsProvider.class)
+    void testInvalidMemoryRecords(MemoryRecords records, Optional<Class<Exception>> expectedException) throws Exception {
+        // CorruptRecordException is handled by the KafkaRaftClient, so ignore the expected exception
+        testFetchResponseWithInvalidRecord(records, Integer.MAX_VALUE);
+    }
+
+    private static void testFetchResponseWithInvalidRecord(MemoryRecords records, int epoch) throws Exception {
+        int localId = KafkaRaftClientTest.randomReplicaId();
+        ReplicaKey local = KafkaRaftClientTest.replicaKey(localId, true);
+        ReplicaKey electedLeader = KafkaRaftClientTest.replicaKey(localId + 1, true);
+
+        RaftClientTestContext context = new RaftClientTestContext.Builder(
+            local.id(),
+            local.directoryId().get()
+        )
+            .withStartingVoters(
+                VoterSetTest.voterSet(Stream.of(local, electedLeader)), KRaftVersion.KRAFT_VERSION_1
+            )
+            .withElectedLeader(epoch, electedLeader.id())
+            .withRaftProtocol(RaftClientTestContext.RaftProtocol.KIP_996_PROTOCOL)
+            .build();
+
+        context.pollUntilRequest();
+        RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest();
+        context.assertFetchRequestData(fetchRequest, epoch, 0L, 0);
+
+        long oldLogEndOffset = context.log.endOffset().offset();
+
+        context.deliverResponse(
+            fetchRequest.correlationId(),
+            fetchRequest.destination(),
+            context.fetchResponse(epoch, electedLeader.id(), records, 0L, Errors.NONE)
+        );
+
+        context.client.poll();
+
+        assertEquals(oldLogEndOffset, context.log.endOffset().offset());
+    }
+
+    @Test
+    void testReplicationOfHigherPartitionLeaderEpoch() throws Exception {
+        int epoch = 2;
+        int localId = KafkaRaftClientTest.randomReplicaId();
+        ReplicaKey local = KafkaRaftClientTest.replicaKey(localId, true);
+        ReplicaKey electedLeader = KafkaRaftClientTest.replicaKey(localId + 1, true);
+
+        RaftClientTestContext context = new RaftClientTestContext.Builder(
+            local.id(),
+            local.directoryId().get()
+        )
+            .withStartingVoters(
+                VoterSetTest.voterSet(Stream.of(local, electedLeader)), KRaftVersion.KRAFT_VERSION_1
+            )
+            .withElectedLeader(epoch, electedLeader.id())
+            .withRaftProtocol(RaftClientTestContext.RaftProtocol.KIP_996_PROTOCOL)
+            .build();
+
+        context.pollUntilRequest();
+        RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest();
+        context.assertFetchRequestData(fetchRequest, epoch, 0L, 0);
+
+        long oldLogEndOffset = context.log.endOffset().offset();
+        int numberOfRecords = 10;
+        MemoryRecords batchWithValidEpoch = MemoryRecords.withRecords(
+            oldLogEndOffset,
+            Compression.NONE,
+            epoch,
+            IntStream
+                .range(0, numberOfRecords)
+                .mapToObj(number -> new SimpleRecord(Integer.toString(number).getBytes()))
+                .toArray(SimpleRecord[]::new)
+        );
+
+        MemoryRecords batchWithInvalidEpoch = MemoryRecords.withRecords(
+            oldLogEndOffset + numberOfRecords,
+            Compression.NONE,
+            epoch + 1,
+            IntStream
+                .range(0, numberOfRecords)
+                .mapToObj(number -> new SimpleRecord(Integer.toString(number).getBytes()))
+                .toArray(SimpleRecord[]::new)
+        );
+
+        var buffer = ByteBuffer.allocate(batchWithValidEpoch.sizeInBytes() + batchWithInvalidEpoch.sizeInBytes());
+        buffer.put(batchWithValidEpoch.buffer());
+        buffer.put(batchWithInvalidEpoch.buffer());
+        buffer.flip();
+
+        MemoryRecords records = MemoryRecords.readableRecords(buffer);
+
+        context.deliverResponse(
+            fetchRequest.correlationId(),
+            fetchRequest.destination(),
+            context.fetchResponse(epoch, electedLeader.id(), records, 0L, Errors.NONE)
+        );
+
+        context.client.poll();
+
+        // Check that only the first batch was appended because the second batch has a greater epoch
+        assertEquals(oldLogEndOffset + numberOfRecords, context.log.endOffset().offset());
+    }
+}
diff --git a/raft/src/test/java/org/apache/kafka/raft/MockLog.java b/raft/src/test/java/org/apache/kafka/raft/MockLog.java
index a7a8e89a88c..9fb4724cc0c 100644
--- a/raft/src/test/java/org/apache/kafka/raft/MockLog.java
+++ b/raft/src/test/java/org/apache/kafka/raft/MockLog.java
@@ -19,6 +19,7 @@ package org.apache.kafka.raft;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.Uuid;
 import org.apache.kafka.common.compress.Compression;
+import org.apache.kafka.common.errors.CorruptRecordException;
 import org.apache.kafka.common.errors.OffsetOutOfRangeException;
 import org.apache.kafka.common.record.MemoryRecords;
 import org.apache.kafka.common.record.MemoryRecordsBuilder;
@@ -279,7 +280,7 @@ public class MockLog implements ReplicatedLog {
 
     @Override
     public LogAppendInfo appendAsLeader(Records records, int epoch) {
-        return append(records, OptionalInt.of(epoch));
+        return append(records, epoch, true);
     }
 
     private long appendBatch(LogBatch batch) {
@@ -292,16 +293,18 @@ public class MockLog implements ReplicatedLog {
     }
 
     @Override
-    public LogAppendInfo appendAsFollower(Records records) {
-        return append(records, OptionalInt.empty());
+    public LogAppendInfo appendAsFollower(Records records, int epoch) {
+        return append(records, epoch, false);
     }
 
-    private LogAppendInfo append(Records records, OptionalInt epoch) {
-        if (records.sizeInBytes() == 0)
+    private LogAppendInfo append(Records records, int epoch, boolean isLeader) {
+        if (records.sizeInBytes() == 0) {
             throw new IllegalArgumentException("Attempt to append an empty record set");
+        }
 
         long baseOffset = endOffset().offset();
         long lastOffset = baseOffset;
+        boolean hasBatches = false;
         for (RecordBatch batch : records.batches()) {
             if (batch.baseOffset() != endOffset().offset()) {
                /* KafkaMetadataLog throws an kafka.common.UnexpectedAppendOffsetException this is the
@@ -314,26 +317,47 @@ public class MockLog implements ReplicatedLog {
                         endOffset().offset()
                     )
                 );
+            } else if (isLeader && epoch != batch.partitionLeaderEpoch()) {
+                // the partition leader epoch is set and does not match the one set in the batch
+                throw new RuntimeException(
+                    String.format(
+                        "Epoch %s doesn't match batch leader epoch %s",
+                        epoch,
+                        batch.partitionLeaderEpoch()
+                    )
+                );
+            } else if (!isLeader && batch.partitionLeaderEpoch() > epoch) {
+                /* To avoid inconsistent log replication, a follower should only append record
+                 * batches with an epoch less than or equal to the leader epoch. There are more
+                 * details on this issue and scenario in KAFKA-18723.
+                 */
+                break;
             }
 
+            hasBatches = true;
             LogBatch logBatch = new LogBatch(
-                epoch.orElseGet(batch::partitionLeaderEpoch),
+                batch.partitionLeaderEpoch(),
                 batch.isControlBatch(),
                 buildEntries(batch, Record::offset)
             );
 
             if (logger.isDebugEnabled()) {
-                String nodeState = "Follower";
-                if (epoch.isPresent()) {
-                    nodeState = "Leader";
-                }
-                logger.debug("{} appending to the log {}", nodeState, logBatch);
+                logger.debug(
+                    "{} appending to the log {}",
+                    isLeader ? "Leader" : "Follower",
+                    logBatch
+                );
             }
 
             appendBatch(logBatch);
             lastOffset = logBatch.last().offset;
         }
 
+        if (!hasBatches) {
+            // This emulates the default handling when the records don't have enough bytes for a batch
+            throw new CorruptRecordException("Append failed unexpectedly");
+        }
+
         return new LogAppendInfo(baseOffset, lastOffset);
     }
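
    The hasBatches guard above covers the case where the fetched bytes are too short to contain even one complete batch header; iterating such a buffer simply yields no batches at all. A small standalone illustration of that behaviour follows; the class name and the 8-byte buffer are made up for the example:

    import org.apache.kafka.common.record.MemoryRecords;
    import org.apache.kafka.common.record.MutableRecordBatch;

    import java.nio.ByteBuffer;

    public class TruncatedBatchSketch {
        public static void main(String[] args) {
            // 8 zero bytes: fewer than a record-batch header, so no complete batch can be read.
            MemoryRecords records = MemoryRecords.readableRecords(ByteBuffer.allocate(8));

            int batchCount = 0;
            for (MutableRecordBatch batch : records.batches()) {
                batchCount++;
            }

            // Prints 0: the iteration silently yields nothing, which is the case the
            // hasBatches flag turns into a CorruptRecordException.
            System.out.println(batchCount);
        }
    }
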
 
diff --git a/raft/src/test/java/org/apache/kafka/raft/MockLogTest.java b/raft/src/test/java/org/apache/kafka/raft/MockLogTest.java
index 8306e103258..eca0fe5d3de 100644
--- a/raft/src/test/java/org/apache/kafka/raft/MockLogTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/MockLogTest.java
@@ -19,9 +19,12 @@ package org.apache.kafka.raft;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.Uuid;
 import org.apache.kafka.common.compress.Compression;
+import org.apache.kafka.common.errors.CorruptRecordException;
 import org.apache.kafka.common.errors.OffsetOutOfRangeException;
 import org.apache.kafka.common.message.LeaderChangeMessage;
+import org.apache.kafka.common.record.ArbitraryMemoryRecords;
 import org.apache.kafka.common.record.ControlRecordUtils;
+import org.apache.kafka.common.record.InvalidMemoryRecordsProvider;
 import org.apache.kafka.common.record.MemoryRecords;
 import org.apache.kafka.common.record.Record;
 import org.apache.kafka.common.record.RecordBatch;
@@ -32,9 +35,16 @@ import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.snapshot.RawSnapshotReader;
 import org.apache.kafka.snapshot.RawSnapshotWriter;
 
+import net.jqwik.api.AfterFailureMode;
+import net.jqwik.api.ForAll;
+import net.jqwik.api.Property;
+
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.function.Executable;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ArgumentsSource;
 
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
@@ -44,6 +54,7 @@ import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
+import java.util.stream.IntStream;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
@@ -169,14 +180,17 @@ public class MockLogTest {
         assertThrows(
             RuntimeException.class,
             () -> log.appendAsLeader(
-                    MemoryRecords.withRecords(initialOffset, Compression.NONE, currentEpoch, recordFoo),
-                    currentEpoch)
+                MemoryRecords.withRecords(initialOffset, Compression.NONE, currentEpoch, recordFoo),
+                currentEpoch
+            )
         );
 
         assertThrows(
             RuntimeException.class,
             () -> log.appendAsFollower(
-                    MemoryRecords.withRecords(initialOffset, Compression.NONE, currentEpoch, recordFoo))
+                MemoryRecords.withRecords(initialOffset, Compression.NONE, currentEpoch, recordFoo),
+                currentEpoch
+            )
         );
     }
 
@@ -187,7 +201,13 @@ public class MockLogTest {
         LeaderChangeMessage messageData =  new LeaderChangeMessage().setLeaderId(0);
         ByteBuffer buffer = ByteBuffer.allocate(256);
         log.appendAsLeader(
-            MemoryRecords.withLeaderChangeMessage(initialOffset, 0L, 2, buffer, messageData),
+            MemoryRecords.withLeaderChangeMessage(
+                initialOffset,
+                0L,
+                currentEpoch,
+                buffer,
+                messageData
+            ),
             currentEpoch
         );
 
@@ -221,7 +241,10 @@ public class MockLogTest {
         }
         log.truncateToLatestSnapshot();
 
-        log.appendAsFollower(MemoryRecords.withRecords(initialOffset, Compression.NONE, epoch, recordFoo));
+        log.appendAsFollower(
+            MemoryRecords.withRecords(initialOffset, Compression.NONE, epoch, recordFoo),
+            epoch
+        );
 
         assertEquals(initialOffset, log.startOffset());
         assertEquals(initialOffset + 1, log.endOffset().offset());
@@ -368,10 +391,82 @@ public class MockLogTest {
 
     @Test
     public void testEmptyAppendNotAllowed() {
-        assertThrows(IllegalArgumentException.class, () -> log.appendAsFollower(MemoryRecords.EMPTY));
+        assertThrows(IllegalArgumentException.class, () -> log.appendAsFollower(MemoryRecords.EMPTY, 1));
         assertThrows(IllegalArgumentException.class, () -> log.appendAsLeader(MemoryRecords.EMPTY, 1));
     }
 
+    @ParameterizedTest
+    @ArgumentsSource(InvalidMemoryRecordsProvider.class)
+    void testInvalidMemoryRecords(MemoryRecords records, Optional<Class<Exception>> expectedException) {
+        long previousEndOffset = log.endOffset().offset();
+
+        Executable action = () -> log.appendAsFollower(records, Integer.MAX_VALUE);
+        if (expectedException.isPresent()) {
+            assertThrows(expectedException.get(), action);
+        } else {
+            assertThrows(CorruptRecordException.class, action);
+        }
+
+        assertEquals(previousEndOffset, log.endOffset().offset());
+    }
+
+    @Property(tries = 100, afterFailure = AfterFailureMode.SAMPLE_ONLY)
+    void testRandomRecords(
+        @ForAll(supplier = ArbitraryMemoryRecords.class) MemoryRecords records
+    ) {
+        try (MockLog log = new MockLog(topicPartition, topicId, new LogContext())) {
+            long previousEndOffset = log.endOffset().offset();
+
+            assertThrows(
+                CorruptRecordException.class,
+                () -> log.appendAsFollower(records, Integer.MAX_VALUE)
+            );
+
+            assertEquals(previousEndOffset, log.endOffset().offset());
+        }
+    }
+
+    @Test
+    void testInvalidLeaderEpoch() {
+        var previousEndOffset = log.endOffset().offset();
+        var epoch = log.lastFetchedEpoch() + 1;
+        var numberOfRecords = 10;
+
+        MemoryRecords batchWithValidEpoch = MemoryRecords.withRecords(
+            previousEndOffset,
+            Compression.NONE,
+            epoch,
+            IntStream
+                .range(0, numberOfRecords)
+                .mapToObj(number -> new SimpleRecord(Integer.toString(number).getBytes()))
+                .toArray(SimpleRecord[]::new)
+        );
+
+        MemoryRecords batchWithInvalidEpoch = MemoryRecords.withRecords(
+            previousEndOffset + numberOfRecords,
+            Compression.NONE,
+            epoch + 1,
+            IntStream
+                .range(0, numberOfRecords)
+                .mapToObj(number -> new SimpleRecord(Integer.toString(number).getBytes()))
+                .toArray(SimpleRecord[]::new)
+        );
+
+        var buffer = ByteBuffer.allocate(batchWithValidEpoch.sizeInBytes() + batchWithInvalidEpoch.sizeInBytes());
+        buffer.put(batchWithValidEpoch.buffer());
+        buffer.put(batchWithInvalidEpoch.buffer());
+        buffer.flip();
+
+        var records = MemoryRecords.readableRecords(buffer);
+
+        log.appendAsFollower(records, epoch);
+
+        // Check that only the first batch was appended
+        assertEquals(previousEndOffset + numberOfRecords, log.endOffset().offset());
+        // Check that the last fetched epoch matches the first batch
+        assertEquals(epoch, log.lastFetchedEpoch());
+    }
+
     @Test
     public void testReadOutOfRangeOffset() {
         final long initialOffset = 5L;
@@ -383,12 +478,19 @@ public class MockLogTest {
         }
         log.truncateToLatestSnapshot();
 
-        log.appendAsFollower(MemoryRecords.withRecords(initialOffset, Compression.NONE, epoch, recordFoo));
+        log.appendAsFollower(
+            MemoryRecords.withRecords(initialOffset, Compression.NONE, epoch, recordFoo),
+            epoch
+        );
 
-        assertThrows(OffsetOutOfRangeException.class, () -> log.read(log.startOffset() - 1,
-            Isolation.UNCOMMITTED));
-        assertThrows(OffsetOutOfRangeException.class, () -> log.read(log.endOffset().offset() + 1,
-            Isolation.UNCOMMITTED));
+        assertThrows(
+            OffsetOutOfRangeException.class,
+            () -> log.read(log.startOffset() - 1, Isolation.UNCOMMITTED)
+        );
+        assertThrows(
+            OffsetOutOfRangeException.class,
+            () -> log.read(log.endOffset().offset() + 1, Isolation.UNCOMMITTED)
+        );
     }
 
     @Test
@@ -958,6 +1060,7 @@ public class MockLogTest {
             MemoryRecords.withRecords(
                 log.endOffset().offset(),
                 Compression.NONE,
+                epoch,
                 records.toArray(new SimpleRecord[records.size()])
             ),
             epoch

