This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 8efdbce523 KAFKA-13837; Return an error from Fetch if follower is not a valid replica (#12150)
8efdbce523 is described below

commit 8efdbce5231f3b5ef61deb827c41b0a8c50aa84a
Author: Jason Gustafson <ja...@confluent.io>
AuthorDate: Wed May 18 20:58:20 2022 -0700

    KAFKA-13837; Return an error from Fetch if follower is not a valid replica (#12150)
    
    When a partition leader receives a `Fetch` request from a replica that is not in the current replica set, the current behavior is to return a successful fetch response with empty data. The follower then keeps retrying until metadata converges, and none of these fetches update any state on the leader side. It is clearer in this case to return an error, so that the metadata inconsistency is visible in the logs and the follower backs off before retrying.
    
    In this patch, we use `UNKNOWN_LEADER_EPOCH` when the `Fetch` request includes the current leader epoch. The rationale is that the leader validates the (replicaId, leaderEpoch) tuple as a whole: when the leader returns `UNKNOWN_LEADER_EPOCH`, it means that the leader does not expect the given leaderEpoch from that replica. If the request does not include a leader epoch, then we use `NOT_LEADER_OR_FOLLOWER`. We can take a similar interpretation for this case: the leader is rejecting the [...]
    
    As a part of this patch, I have refactored the way that the leader updates follower fetch state. Previously, the process was a little convoluted: we sent the fetch from `ReplicaManager` down to `Partition.readRecords`, then iterated over the results and called `Partition.updateFollowerFetchState`. It is more straightforward to update the state directly as part of the read, which now goes through the new `Partition.fetchRecords` entry point (`readRecords` becomes private). All we need to do is pass through the `FetchParams`. This also avoids an unnecessary copy of the read results.
    
    Reviewers: David Jacot <dja...@confluent.io>
---
 core/src/main/scala/kafka/cluster/Partition.scala  | 234 +++++++----
 core/src/main/scala/kafka/log/UnifiedLog.scala     |   2 +-
 .../src/main/scala/kafka/server/DelayedFetch.scala |  17 +-
 .../main/scala/kafka/server/FetchDataInfo.scala    |  13 +-
 .../main/scala/kafka/server/ReplicaManager.scala   | 116 ++----
 .../kafka/server/DelayedFetchTest.scala            |  28 +-
 .../unit/kafka/cluster/AbstractPartitionTest.scala |   5 +-
 .../unit/kafka/cluster/PartitionLockTest.scala     |  90 +++--
 .../scala/unit/kafka/cluster/PartitionTest.scala   | 435 +++++++++++++--------
 .../FetchRequestDownConversionConfigTest.scala     | 145 ++++---
 .../kafka/server/ReplicaManagerQuotasTest.scala    |  68 +---
 .../unit/kafka/server/ReplicaManagerTest.scala     |  32 +-
 .../test/scala/unit/kafka/utils/TestUtils.scala    |  74 ++--
 .../UpdateFollowerFetchStateBenchmark.java         |  13 +-
 14 files changed, 760 insertions(+), 512 deletions(-)

diff --git a/core/src/main/scala/kafka/cluster/Partition.scala 
b/core/src/main/scala/kafka/cluster/Partition.scala
index 9864480e78..61d5f707dc 100755
--- a/core/src/main/scala/kafka/cluster/Partition.scala
+++ b/core/src/main/scala/kafka/cluster/Partition.scala
@@ -440,8 +440,10 @@ class Partition(val topicPartition: TopicPartition,
     leaderReplicaIdOpt.filter(_ == localBrokerId)
   }
 
-  private def localLogWithEpochOrException(currentLeaderEpoch: 
Optional[Integer],
-                                           requireLeader: Boolean): UnifiedLog 
= {
+  private def localLogWithEpochOrThrow(
+    currentLeaderEpoch: Optional[Integer],
+    requireLeader: Boolean
+  ): UnifiedLog = {
     getLocalLog(currentLeaderEpoch, requireLeader) match {
       case Left(localLog) => localLog
       case Right(error) =>
@@ -719,55 +721,51 @@ class Partition(val topicPartition: TopicPartition,
    * Update the follower's state in the leader based on the last fetch 
request. See
    * [[Replica.updateFetchState()]] for details.
    *
-   * @return true if the follower's fetch state was updated, false if the 
followerId is not recognized
+   * This method is visible for performance testing (see 
`UpdateFollowerFetchStateBenchmark`)
    */
-  def updateFollowerFetchState(followerId: Int,
-                               followerFetchOffsetMetadata: LogOffsetMetadata,
-                               followerStartOffset: Long,
-                               followerFetchTimeMs: Long,
-                               leaderEndOffset: Long): Boolean = {
-    getReplica(followerId) match {
-      case Some(followerReplica) =>
-        // No need to calculate low watermark if there is no delayed 
DeleteRecordsRequest
-        val oldLeaderLW = if (delayedOperations.numDelayedDelete > 0) 
lowWatermarkIfLeader else -1L
-        val prevFollowerEndOffset = followerReplica.stateSnapshot.logEndOffset
-        followerReplica.updateFetchState(
-          followerFetchOffsetMetadata,
-          followerStartOffset,
-          followerFetchTimeMs,
-          leaderEndOffset)
-
-        val newLeaderLW = if (delayedOperations.numDelayedDelete > 0) 
lowWatermarkIfLeader else -1L
-        // check if the LW of the partition has incremented
-        // since the replica's logStartOffset may have incremented
-        val leaderLWIncremented = newLeaderLW > oldLeaderLW
-
-        // Check if this in-sync replica needs to be added to the ISR.
-        maybeExpandIsr(followerReplica)
-
-        // check if the HW of the partition can now be incremented
-        // since the replica may already be in the ISR and its LEO has just 
incremented
-        val leaderHWIncremented = if (prevFollowerEndOffset != 
followerReplica.stateSnapshot.logEndOffset) {
-          // the leader log may be updated by ReplicaAlterLogDirsThread so the 
following method must be in lock of
-          // leaderIsrUpdateLock to prevent adding new hw to invalid log.
-          inReadLock(leaderIsrUpdateLock) {
-            leaderLogIfLocal.exists(leaderLog => 
maybeIncrementLeaderHW(leaderLog, followerFetchTimeMs))
-          }
-        } else {
-          false
-        }
-
-        // some delayed operations may be unblocked after HW or LW changed
-        if (leaderLWIncremented || leaderHWIncremented)
-          tryCompleteDelayedRequests()
+  def updateFollowerFetchState(
+    replica: Replica,
+    followerFetchOffsetMetadata: LogOffsetMetadata,
+    followerStartOffset: Long,
+    followerFetchTimeMs: Long,
+    leaderEndOffset: Long
+  ): Unit = {
+    // No need to calculate low watermark if there is no delayed 
DeleteRecordsRequest
+    val oldLeaderLW = if (delayedOperations.numDelayedDelete > 0) 
lowWatermarkIfLeader else -1L
+    val prevFollowerEndOffset = replica.stateSnapshot.logEndOffset
+    replica.updateFetchState(
+      followerFetchOffsetMetadata,
+      followerStartOffset,
+      followerFetchTimeMs,
+      leaderEndOffset
+    )
+
+    val newLeaderLW = if (delayedOperations.numDelayedDelete > 0) 
lowWatermarkIfLeader else -1L
+    // check if the LW of the partition has incremented
+    // since the replica's logStartOffset may have incremented
+    val leaderLWIncremented = newLeaderLW > oldLeaderLW
+
+    // Check if this in-sync replica needs to be added to the ISR.
+    maybeExpandIsr(replica)
+
+    // check if the HW of the partition can now be incremented
+    // since the replica may already be in the ISR and its LEO has just 
incremented
+    val leaderHWIncremented = if (prevFollowerEndOffset != 
replica.stateSnapshot.logEndOffset) {
+      // the leader log may be updated by ReplicaAlterLogDirsThread so the 
following method must be in lock of
+      // leaderIsrUpdateLock to prevent adding new hw to invalid log.
+      inReadLock(leaderIsrUpdateLock) {
+        leaderLogIfLocal.exists(leaderLog => maybeIncrementLeaderHW(leaderLog, 
followerFetchTimeMs))
+      }
+    } else {
+      false
+    }
 
-        debug(s"Recorded replica $followerId log end offset (LEO) position " +
-          s"${followerFetchOffsetMetadata.messageOffset} and log start offset 
$followerStartOffset.")
-        true
+    // some delayed operations may be unblocked after HW or LW changed
+    if (leaderLWIncremented || leaderHWIncremented)
+      tryCompleteDelayedRequests()
 
-      case None =>
-        false
-    }
+    debug(s"Recorded replica ${replica.brokerId} log end offset (LEO) position 
" +
+      s"${followerFetchOffsetMetadata.messageOffset} and log start offset 
$followerStartOffset.")
   }
 
   /**
@@ -1145,15 +1143,112 @@ class Partition(val topicPartition: TopicPartition,
     info.copy(leaderHwChange = if (leaderHWIncremented) 
LeaderHwChange.Increased else LeaderHwChange.Same)
   }
 
-  def readRecords(lastFetchedEpoch: Optional[Integer],
-                  fetchOffset: Long,
-                  currentLeaderEpoch: Optional[Integer],
-                  maxBytes: Int,
-                  fetchIsolation: FetchIsolation,
-                  fetchOnlyFromLeader: Boolean,
-                  minOneMessage: Boolean): LogReadInfo = 
inReadLock(leaderIsrUpdateLock) {
-    // decide whether to only fetch from leader
-    val localLog = localLogWithEpochOrException(currentLeaderEpoch, 
fetchOnlyFromLeader)
+  /**
+   * Fetch records from the partition.
+   *
+   * @param fetchParams parameters of the corresponding `Fetch` request
+   * @param fetchPartitionData partition-level parameters of the `Fetch` (e.g. 
the fetch offset)
+   * @param fetchTimeMs current time in milliseconds on the broker of this 
fetch request
+   * @param maxBytes the maximum bytes to return
+   * @param minOneMessage whether to ensure that at least one complete message 
is returned
+   * @param updateFetchState true if the Fetch should update replica state 
(only applies to follower fetches)
+   * @return [[LogReadInfo]] containing the fetched records or the diverging 
epoch if present
+   * @throws NotLeaderOrFollowerException if this node is not the current 
leader and [[FetchParams.fetchOnlyLeader]]
+   *                                      is enabled, or if this is a follower 
fetch with an older request version
+   *                                      and the replicaId is not recognized 
among the current valid replicas
+   * @throws FencedLeaderEpochException if the leader epoch in the `Fetch` 
request is lower than the current
+   *                                    leader epoch
+   * @throws UnknownLeaderEpochException if the leader epoch in the `Fetch` 
request is higher than the current
+   *                                     leader epoch, or if this is a 
follower fetch and the replicaId is not
+   *                                     recognized among the current valid 
replicas
+   * @throws OffsetOutOfRangeException if the fetch offset is smaller than the 
log start offset or larger than
+   *                                   the log end offset (or high watermark 
depending on [[FetchParams.isolation]]),
+   *                                   or if the end offset for the last 
fetched epoch in [[FetchRequest.PartitionData]]
+   *                                   cannot be determined from the local 
epoch cache (e.g. if it is larger than
+   *                                   any cached epoch value)
+   */
+  def fetchRecords(
+    fetchParams: FetchParams,
+    fetchPartitionData: FetchRequest.PartitionData,
+    fetchTimeMs: Long,
+    maxBytes: Int,
+    minOneMessage: Boolean,
+    updateFetchState: Boolean
+  ): LogReadInfo = {
+    def readFromLocalLog(): LogReadInfo = {
+      readRecords(
+        fetchPartitionData.lastFetchedEpoch,
+        fetchPartitionData.fetchOffset,
+        fetchPartitionData.currentLeaderEpoch,
+        maxBytes,
+        fetchParams.isolation,
+        minOneMessage,
+        fetchParams.fetchOnlyLeader
+      )
+    }
+
+    if (fetchParams.isFromFollower) {
+      // Check that the request is from a valid replica before doing the read
+      val replica = followerReplicaOrThrow(fetchParams.replicaId, 
fetchPartitionData)
+      val logReadInfo = readFromLocalLog()
+
+      if (updateFetchState && logReadInfo.divergingEpoch.isEmpty) {
+        updateFollowerFetchState(
+          replica,
+          followerFetchOffsetMetadata = 
logReadInfo.fetchedData.fetchOffsetMetadata,
+          followerStartOffset = fetchPartitionData.logStartOffset,
+          followerFetchTimeMs = fetchTimeMs,
+          leaderEndOffset = logReadInfo.logEndOffset
+        )
+      }
+
+      logReadInfo
+    } else {
+      readFromLocalLog()
+    }
+  }
+
+  private def followerReplicaOrThrow(
+    replicaId: Int,
+    fetchPartitionData: FetchRequest.PartitionData
+  ): Replica = {
+    getReplica(replicaId).getOrElse {
+      debug(s"Leader $localBrokerId failed to record follower $replicaId's 
position " +
+        s"${fetchPartitionData.fetchOffset}, and last sent high watermark 
since the replica is " +
+        s"not recognized to be one of the assigned replicas 
${assignmentState.replicas.mkString(",")} " +
+        s"for leader epoch $leaderEpoch with partition epoch $partitionEpoch")
+
+      val error = if (fetchPartitionData.currentLeaderEpoch.isPresent) {
+        // The leader epoch is present in the request and matches the local 
epoch, but
+        // the replica is not in the replica set. This case is possible in 
KRaft,
+        // for example, when new replicas are added as part of a reassignment.
+        // We return UNKNOWN_LEADER_EPOCH to signify that the tuple 
(replicaId, leaderEpoch)
+        // is not yet recognized as valid, which causes the follower to retry.
+        Errors.UNKNOWN_LEADER_EPOCH
+      } else {
+        // The request has no leader epoch, which means it is an older 
version. We cannot
+        // say if the follower's state is stale or the local state is. In this 
case, we
+        // return `NOT_LEADER_OR_FOLLOWER` for lack of a better error so that 
the follower
+        // will retry.
+        Errors.NOT_LEADER_OR_FOLLOWER
+      }
+
+      throw error.exception(s"Replica $replicaId is not recognized as a " +
+        s"valid replica of $topicPartition in leader epoch $leaderEpoch with " 
+
+        s"partition epoch $partitionEpoch")
+    }
+  }
+
+  private def readRecords(
+    lastFetchedEpoch: Optional[Integer],
+    fetchOffset: Long,
+    currentLeaderEpoch: Optional[Integer],
+    maxBytes: Int,
+    fetchIsolation: FetchIsolation,
+    minOneMessage: Boolean,
+    fetchOnlyFromLeader: Boolean
+  ): LogReadInfo = inReadLock(leaderIsrUpdateLock) {
+    val localLog = localLogWithEpochOrThrow(currentLeaderEpoch, 
fetchOnlyFromLeader)
 
     // Note we use the log end offset prior to the read. This ensures that any 
appends following
     // the fetch do not prevent a follower from coming into sync.
@@ -1181,18 +1276,12 @@ class Partition(val topicPartition: TopicPartition,
       }
 
       if (epochEndOffset.leaderEpoch < fetchEpoch || epochEndOffset.endOffset 
< fetchOffset) {
-        val emptyFetchData = FetchDataInfo(
-          fetchOffsetMetadata = LogOffsetMetadata(fetchOffset),
-          records = MemoryRecords.EMPTY,
-          abortedTransactions = None
-        )
-
         val divergingEpoch = new FetchResponseData.EpochEndOffset()
           .setEpoch(epochEndOffset.leaderEpoch)
           .setEndOffset(epochEndOffset.endOffset)
 
         return LogReadInfo(
-          fetchedData = emptyFetchData,
+          fetchedData = FetchDataInfo.empty(fetchOffset),
           divergingEpoch = Some(divergingEpoch),
           highWatermark = initialHighWatermark,
           logStartOffset = initialLogStartOffset,
@@ -1201,14 +1290,21 @@ class Partition(val topicPartition: TopicPartition,
       }
     }
 
-    val fetchedData = localLog.read(fetchOffset, maxBytes, fetchIsolation, 
minOneMessage)
+    val fetchedData = localLog.read(
+      fetchOffset,
+      maxBytes,
+      fetchIsolation,
+      minOneMessage
+    )
+
     LogReadInfo(
       fetchedData = fetchedData,
       divergingEpoch = None,
       highWatermark = initialHighWatermark,
       logStartOffset = initialLogStartOffset,
       logEndOffset = initialLogEndOffset,
-      lastStableOffset = initialLastStableOffset)
+      lastStableOffset = initialLastStableOffset
+    )
   }
 
   def fetchOffsetForTimestamp(timestamp: Long,
@@ -1216,7 +1312,7 @@ class Partition(val topicPartition: TopicPartition,
                               currentLeaderEpoch: Optional[Integer],
                               fetchOnlyFromLeader: Boolean): 
Option[TimestampAndOffset] = inReadLock(leaderIsrUpdateLock) {
     // decide whether to only fetch from leader
-    val localLog = localLogWithEpochOrException(currentLeaderEpoch, 
fetchOnlyFromLeader)
+    val localLog = localLogWithEpochOrThrow(currentLeaderEpoch, 
fetchOnlyFromLeader)
 
     val lastFetchableOffset = isolationLevel match {
       case Some(IsolationLevel.READ_COMMITTED) => localLog.lastStableOffset
@@ -1277,7 +1373,7 @@ class Partition(val topicPartition: TopicPartition,
   def fetchOffsetSnapshot(currentLeaderEpoch: Optional[Integer],
                           fetchOnlyFromLeader: Boolean): LogOffsetSnapshot = 
inReadLock(leaderIsrUpdateLock) {
     // decide whether to only fetch from leader
-    val localLog = localLogWithEpochOrException(currentLeaderEpoch, 
fetchOnlyFromLeader)
+    val localLog = localLogWithEpochOrThrow(currentLeaderEpoch, 
fetchOnlyFromLeader)
     localLog.fetchOffsetSnapshot
   }
 
@@ -1285,7 +1381,7 @@ class Partition(val topicPartition: TopicPartition,
                                      maxNumOffsets: Int,
                                      isFromConsumer: Boolean,
                                      fetchOnlyFromLeader: Boolean): Seq[Long] 
= inReadLock(leaderIsrUpdateLock) {
-    val localLog = localLogWithEpochOrException(Optional.empty(), 
fetchOnlyFromLeader)
+    val localLog = localLogWithEpochOrThrow(Optional.empty(), 
fetchOnlyFromLeader)
     val allOffsets = localLog.legacyFetchOffsetsBefore(timestamp, 
maxNumOffsets)
 
     if (!isFromConsumer) {
diff --git a/core/src/main/scala/kafka/log/UnifiedLog.scala 
b/core/src/main/scala/kafka/log/UnifiedLog.scala
index 99524385fb..ddd66eb160 100644
--- a/core/src/main/scala/kafka/log/UnifiedLog.scala
+++ b/core/src/main/scala/kafka/log/UnifiedLog.scala
@@ -147,7 +147,7 @@ case class LogOffsetSnapshot(logStartOffset: Long,
                              lastStableOffset: LogOffsetMetadata)
 
 /**
- * Another container which is used for lower level reads using  
[[kafka.cluster.Partition.readRecords()]].
+ * Another container which is used for lower level reads using  
[[kafka.cluster.Partition.fetchRecords()]].
  */
 case class LogReadInfo(fetchedData: FetchDataInfo,
                        divergingEpoch: 
Option[FetchResponseData.EpochEndOffset],
diff --git a/core/src/main/scala/kafka/server/DelayedFetch.scala 
b/core/src/main/scala/kafka/server/DelayedFetch.scala
index 3eb8eedf4c..55a15682b6 100644
--- a/core/src/main/scala/kafka/server/DelayedFetch.scala
+++ b/core/src/main/scala/kafka/server/DelayedFetch.scala
@@ -158,15 +158,16 @@ class DelayedFetch(
    * Upon completion, read whatever data is available and pass to the complete 
callback
    */
   override def onComplete(): Unit = {
+    val fetchInfos = fetchPartitionStatus.map { case (tp, status) =>
+      tp -> status.fetchInfo
+    }
+
     val logReadResults = replicaManager.readFromLocalLog(
-      replicaId = params.replicaId,
-      fetchOnlyFromLeader = params.fetchOnlyLeader,
-      fetchIsolation = params.isolation,
-      fetchMaxBytes = params.maxBytes,
-      hardMaxBytesLimit = params.hardMaxBytesLimit,
-      readPartitionInfo = fetchPartitionStatus.map { case (tp, status) => tp 
-> status.fetchInfo },
-      clientMetadata = params.clientMetadata,
-      quota = quota)
+      params,
+      fetchInfos,
+      quota,
+      readFromPurgatory = true
+    )
 
     val fetchPartitionData = logReadResults.map { case (tp, result) =>
       val isReassignmentFetch = params.isFromFollower &&
diff --git a/core/src/main/scala/kafka/server/FetchDataInfo.scala 
b/core/src/main/scala/kafka/server/FetchDataInfo.scala
index 82e8092c10..95b68c0839 100644
--- a/core/src/main/scala/kafka/server/FetchDataInfo.scala
+++ b/core/src/main/scala/kafka/server/FetchDataInfo.scala
@@ -20,7 +20,7 @@ package kafka.server
 import kafka.api.Request
 import org.apache.kafka.common.IsolationLevel
 import org.apache.kafka.common.message.FetchResponseData
-import org.apache.kafka.common.record.Records
+import org.apache.kafka.common.record.{MemoryRecords, Records}
 import org.apache.kafka.common.replica.ClientMetadata
 import org.apache.kafka.common.requests.FetchRequest
 
@@ -75,6 +75,17 @@ case class FetchParams(
   }
 }
 
+object FetchDataInfo {
+  def empty(fetchOffset: Long): FetchDataInfo = {
+    FetchDataInfo(
+      fetchOffsetMetadata = LogOffsetMetadata(fetchOffset),
+      records = MemoryRecords.EMPTY,
+      firstEntryIncomplete = false,
+      abortedTransactions = None
+    )
+  }
+}
+
 case class FetchDataInfo(
   fetchOffsetMetadata: LogOffsetMetadata,
   records: Records,
diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala 
b/core/src/main/scala/kafka/server/ReplicaManager.scala
index e84abbe5f4..190f80f36c 100644
--- a/core/src/main/scala/kafka/server/ReplicaManager.scala
+++ b/core/src/main/scala/kafka/server/ReplicaManager.scala
@@ -994,29 +994,14 @@ class ReplicaManager(val config: KafkaConfig,
     quota: ReplicaQuota,
     responseCallback: Seq[(TopicIdPartition, FetchPartitionData)] => Unit
   ): Unit = {
-    // Restrict fetching to leader if request is from follower or from a 
client with older version (no ClientMetadata)
-    def readFromLog(): Seq[(TopicIdPartition, LogReadResult)] = {
-      val result = readFromLocalLog(
-        replicaId = params.replicaId,
-        fetchOnlyFromLeader = params.fetchOnlyLeader,
-        fetchIsolation = params.isolation,
-        fetchMaxBytes = params.maxBytes,
-        hardMaxBytesLimit = params.hardMaxBytesLimit,
-        readPartitionInfo = fetchInfos,
-        quota = quota,
-        clientMetadata = params.clientMetadata)
-      if (params.isFromFollower) updateFollowerFetchState(params.replicaId, 
result)
-      else result
-    }
-
-    val logReadResults = readFromLog()
-
     // check if this fetch request can be satisfied right away
+    val logReadResults = readFromLocalLog(params, fetchInfos, quota, 
readFromPurgatory = false)
     var bytesReadable: Long = 0
     var errorReadingData = false
     var hasDivergingEpoch = false
     var hasPreferredReadReplica = false
     val logReadResultMap = new mutable.HashMap[TopicIdPartition, LogReadResult]
+
     logReadResults.foreach { case (topicIdPartition, logReadResult) =>
       
brokerTopicStats.topicStats(topicIdPartition.topicPartition.topic).totalFetchRequestRate.mark()
       brokerTopicStats.allTopicsStats.totalFetchRequestRate.mark()
@@ -1073,14 +1058,12 @@ class ReplicaManager(val config: KafkaConfig,
   /**
    * Read from multiple topic partitions at the given offset up to maxSize 
bytes
    */
-  def readFromLocalLog(replicaId: Int,
-                       fetchOnlyFromLeader: Boolean,
-                       fetchIsolation: FetchIsolation,
-                       fetchMaxBytes: Int,
-                       hardMaxBytesLimit: Boolean,
-                       readPartitionInfo: Seq[(TopicIdPartition, 
PartitionData)],
-                       quota: ReplicaQuota,
-                       clientMetadata: Option[ClientMetadata]): 
Seq[(TopicIdPartition, LogReadResult)] = {
+  def readFromLocalLog(
+    params: FetchParams,
+    readPartitionInfo: Seq[(TopicIdPartition, PartitionData)],
+    quota: ReplicaQuota,
+    readFromPurgatory: Boolean
+  ): Seq[(TopicIdPartition, LogReadResult)] = {
     val traceEnabled = isTraceEnabled
 
     def read(tp: TopicIdPartition, fetchInfo: PartitionData, limitBytes: Int, 
minOneMessage: Boolean): LogReadResult = {
@@ -1104,13 +1087,13 @@ class ReplicaManager(val config: KafkaConfig,
           throw new InconsistentTopicIdException("Topic ID in the fetch 
session did not match the topic ID in the log.")
 
         // If we are the leader, determine the preferred read-replica
-        val preferredReadReplica = clientMetadata.flatMap(
-          metadata => findPreferredReadReplica(partition, metadata, replicaId, 
fetchInfo.fetchOffset, fetchTimeMs))
+        val preferredReadReplica = params.clientMetadata.flatMap(
+          metadata => findPreferredReadReplica(partition, metadata, 
params.replicaId, fetchInfo.fetchOffset, fetchTimeMs))
 
         if (preferredReadReplica.isDefined) {
           replicaSelectorOpt.foreach { selector =>
             debug(s"Replica selector ${selector.getClass.getSimpleName} 
returned preferred replica " +
-              s"${preferredReadReplica.get} for $clientMetadata")
+              s"${preferredReadReplica.get} for ${params.clientMetadata}")
           }
           // If a preferred read-replica is set, skip the read
           val offsetSnapshot = 
partition.fetchOffsetSnapshot(fetchInfo.currentLeaderEpoch, fetchOnlyFromLeader 
= false)
@@ -1126,20 +1109,19 @@ class ReplicaManager(val config: KafkaConfig,
             exception = None)
         } else {
           // Try the read first, this tells us whether we need all of 
adjustedFetchSize for this partition
-          val readInfo: LogReadInfo = partition.readRecords(
-            lastFetchedEpoch = fetchInfo.lastFetchedEpoch,
-            fetchOffset = fetchInfo.fetchOffset,
-            currentLeaderEpoch = fetchInfo.currentLeaderEpoch,
+          val readInfo: LogReadInfo = partition.fetchRecords(
+            fetchParams = params,
+            fetchPartitionData = fetchInfo,
+            fetchTimeMs = fetchTimeMs,
             maxBytes = adjustedMaxBytes,
-            fetchIsolation = fetchIsolation,
-            fetchOnlyFromLeader = fetchOnlyFromLeader,
-            minOneMessage = minOneMessage)
-          val isFromFollower = Request.isValidBrokerId(replicaId)
+            minOneMessage = minOneMessage,
+            updateFetchState = !readFromPurgatory
+          )
 
-          val fetchDataInfo = if (isFromFollower && 
shouldLeaderThrottle(quota, partition, replicaId)) {
+          val fetchDataInfo = if (params.isFromFollower && 
shouldLeaderThrottle(quota, partition, params.replicaId)) {
             // If the partition is being throttled, simply return an empty set.
             FetchDataInfo(readInfo.fetchedData.fetchOffsetMetadata, 
MemoryRecords.EMPTY)
-          } else if (!hardMaxBytesLimit && 
readInfo.fetchedData.firstEntryIncomplete) {
+          } else if (!params.hardMaxBytesLimit && 
readInfo.fetchedData.firstEntryIncomplete) {
             // For FetchRequest version 3, we replace incomplete message sets 
with an empty one as consumers can make
             // progress in such cases and don't need to report a 
`RecordTooLargeException`
             FetchDataInfo(readInfo.fetchedData.fetchOffsetMetadata, 
MemoryRecords.EMPTY)
@@ -1156,7 +1138,8 @@ class ReplicaManager(val config: KafkaConfig,
             fetchTimeMs = fetchTimeMs,
             lastStableOffset = Some(readInfo.lastStableOffset),
             preferredReadReplica = preferredReadReplica,
-            exception = None)
+            exception = None
+          )
         }
       } catch {
         // NOTE: Failed fetch requests metric is not incremented for known 
exceptions since it
@@ -1182,7 +1165,7 @@ class ReplicaManager(val config: KafkaConfig,
           brokerTopicStats.topicStats(tp.topic).failedFetchRequestRate.mark()
           brokerTopicStats.allTopicsStats.failedFetchRequestRate.mark()
 
-          val fetchSource = Request.describeReplicaId(replicaId)
+          val fetchSource = Request.describeReplicaId(params.replicaId)
           error(s"Error processing fetch with max size $adjustedMaxBytes from 
$fetchSource " +
             s"on partition $tp: $fetchInfo", e)
 
@@ -1194,13 +1177,14 @@ class ReplicaManager(val config: KafkaConfig,
             followerLogStartOffset = UnifiedLog.UnknownOffset,
             fetchTimeMs = -1L,
             lastStableOffset = None,
-            exception = Some(e))
+            exception = Some(e)
+          )
       }
     }
 
-    var limitBytes = fetchMaxBytes
+    var limitBytes = params.maxBytes
     val result = new mutable.ArrayBuffer[(TopicIdPartition, LogReadResult)]
-    var minOneMessage = !hardMaxBytesLimit
+    var minOneMessage = !params.hardMaxBytesLimit
     readPartitionInfo.foreach { case (tp, fetchInfo) =>
       val readResult = read(tp, fetchInfo, limitBytes, minOneMessage)
       val recordBatchSize = readResult.info.records.sizeInBytes
@@ -1802,52 +1786,6 @@ class ReplicaManager(val config: KafkaConfig,
     }
   }
 
-  /**
-   * Update the follower's fetch state on the leader based on the last fetch 
request and update `readResult`.
-   * If the follower replica is not recognized to be one of the assigned 
replicas, do not update
-   * `readResult` so that log start/end offset and high watermark is 
consistent with
-   * records in fetch response. Log start/end offset and high watermark may 
change not only due to
-   * this fetch request, e.g., rolling new log segment and removing old log 
segment may move log
-   * start offset further than the last offset in the fetched records. The 
followers will get the
-   * updated leader's state in the next fetch response. If follower has a 
diverging epoch or if read
-   * fails with any error, follower fetch state is not updated.
-   */
-  private def updateFollowerFetchState(followerId: Int,
-                                       readResults: Seq[(TopicIdPartition, 
LogReadResult)]): Seq[(TopicIdPartition, LogReadResult)] = {
-    readResults.map { case (topicIdPartition, readResult) =>
-      val updatedReadResult = if (readResult.error != Errors.NONE) {
-        debug(s"Skipping update of fetch state for follower $followerId since 
the " +
-          s"log read returned error ${readResult.error}")
-        readResult
-      } else if (readResult.divergingEpoch.nonEmpty) {
-        debug(s"Skipping update of fetch state for follower $followerId since 
the " +
-          s"log read returned diverging epoch ${readResult.divergingEpoch}")
-        readResult
-      } else {
-        onlinePartition(topicIdPartition.topicPartition) match {
-          case Some(partition) =>
-            if (partition.updateFollowerFetchState(followerId,
-              followerFetchOffsetMetadata = 
readResult.info.fetchOffsetMetadata,
-              followerStartOffset = readResult.followerLogStartOffset,
-              followerFetchTimeMs = readResult.fetchTimeMs,
-              leaderEndOffset = readResult.leaderLogEndOffset)) {
-              readResult
-            } else {
-              warn(s"Leader $localBrokerId failed to record follower 
$followerId's position " +
-                s"${readResult.info.fetchOffsetMetadata.messageOffset}, and 
last sent HW since the replica " +
-                s"is not recognized to be one of the assigned replicas 
${partition.assignmentState.replicas.mkString(",")} " +
-                s"for partition $topicIdPartition. Empty records will be 
returned for this partition.")
-              readResult.withEmptyFetchInfo
-            }
-          case None =>
-            warn(s"While recording the replica LEO, the partition 
$topicIdPartition hasn't been created.")
-            readResult
-        }
-      }
-      topicIdPartition -> updatedReadResult
-    }
-  }
-
   private def leaderPartitionsIterator: Iterator[Partition] =
     onlinePartitionsIterator.filter(_.leaderLogIfLocal.isDefined)
 
diff --git 
a/core/src/test/scala/integration/kafka/server/DelayedFetchTest.scala 
b/core/src/test/scala/integration/kafka/server/DelayedFetchTest.scala
index 940968f411..dce5a2eaee 100644
--- a/core/src/test/scala/integration/kafka/server/DelayedFetchTest.scala
+++ b/core/src/test/scala/integration/kafka/server/DelayedFetchTest.scala
@@ -73,7 +73,7 @@ class DelayedFetchTest {
         .thenThrow(new FencedLeaderEpochException("Requested epoch has been 
fenced"))
     when(replicaManager.isAddingReplica(any(), anyInt())).thenReturn(false)
 
-    expectReadFromReplica(replicaId, topicIdPartition, fetchStatus.fetchInfo, 
Errors.FENCED_LEADER_EPOCH)
+    expectReadFromReplica(fetchParams, topicIdPartition, 
fetchStatus.fetchInfo, Errors.FENCED_LEADER_EPOCH)
 
     assertTrue(delayedFetch.tryComplete())
     assertTrue(delayedFetch.isCompleted)
@@ -111,7 +111,7 @@ class DelayedFetchTest {
 
     
when(replicaManager.getPartitionOrException(topicIdPartition.topicPartition))
       .thenThrow(new NotLeaderOrFollowerException(s"Replica for 
$topicIdPartition not available"))
-    expectReadFromReplica(replicaId, topicIdPartition, fetchStatus.fetchInfo, 
Errors.NOT_LEADER_OR_FOLLOWER)
+    expectReadFromReplica(fetchParams, topicIdPartition, 
fetchStatus.fetchInfo, Errors.NOT_LEADER_OR_FOLLOWER)
     when(replicaManager.isAddingReplica(any(), anyInt())).thenReturn(false)
 
     assertTrue(delayedFetch.tryComplete())
@@ -160,7 +160,7 @@ class DelayedFetchTest {
         .setLeaderEpoch(lastFetchedEpoch.get)
         .setEndOffset(fetchOffset - 1))
     when(replicaManager.isAddingReplica(any(), anyInt())).thenReturn(false)
-    expectReadFromReplica(replicaId, topicIdPartition, fetchStatus.fetchInfo, 
Errors.NONE)
+    expectReadFromReplica(fetchParams, topicIdPartition, 
fetchStatus.fetchInfo, Errors.NONE)
 
     assertTrue(delayedFetch.tryComplete())
     assertTrue(delayedFetch.isCompleted)
@@ -182,20 +182,18 @@ class DelayedFetchTest {
     )
   }
 
-  private def expectReadFromReplica(replicaId: Int,
-                                    topicIdPartition: TopicIdPartition,
-                                    fetchPartitionData: 
FetchRequest.PartitionData,
-                                    error: Errors): Unit = {
+  private def expectReadFromReplica(
+    fetchParams: FetchParams,
+    topicIdPartition: TopicIdPartition,
+    fetchPartitionData: FetchRequest.PartitionData,
+    error: Errors
+  ): Unit = {
     when(replicaManager.readFromLocalLog(
-      replicaId = replicaId,
-      fetchOnlyFromLeader = true,
-      fetchIsolation = FetchLogEnd,
-      fetchMaxBytes = maxBytes,
-      hardMaxBytesLimit = false,
+      fetchParams,
       readPartitionInfo = Seq((topicIdPartition, fetchPartitionData)),
-      clientMetadata = None,
-      quota = replicaQuota))
-      .thenReturn(Seq((topicIdPartition, buildReadResult(error))))
+      quota = replicaQuota,
+      readFromPurgatory = true
+    )).thenReturn(Seq((topicIdPartition, buildReadResult(error))))
   }
 
   private def buildReadResult(error: Errors): LogReadResult = {
diff --git a/core/src/test/scala/unit/kafka/cluster/AbstractPartitionTest.scala 
b/core/src/test/scala/unit/kafka/cluster/AbstractPartitionTest.scala
index 969f8a2e79..147743a77d 100644
--- a/core/src/test/scala/unit/kafka/cluster/AbstractPartitionTest.scala
+++ b/core/src/test/scala/unit/kafka/cluster/AbstractPartitionTest.scala
@@ -43,6 +43,7 @@ object AbstractPartitionTest {
 class AbstractPartitionTest {
 
   val brokerId = AbstractPartitionTest.brokerId
+  val remoteReplicaId = brokerId + 1
   val topicPartition = new TopicPartition("test-topic", 0)
   val time = new MockTime()
   var tmpDir: File = _
@@ -115,7 +116,7 @@ class AbstractPartitionTest {
     partition.createLogIfNotExists(isNew = false, isFutureReplica = false, 
offsetCheckpoints, None)
 
     val controllerEpoch = 0
-    val replicas = List[Integer](brokerId, brokerId + 1).asJava
+    val replicas = List[Integer](brokerId, remoteReplicaId).asJava
     val isr = replicas
 
     if (isLeader) {
@@ -131,7 +132,7 @@ class AbstractPartitionTest {
     } else {
       assertTrue(partition.makeFollower(new LeaderAndIsrPartitionState()
         .setControllerEpoch(controllerEpoch)
-        .setLeader(brokerId + 1)
+        .setLeader(remoteReplicaId)
         .setLeaderEpoch(leaderEpoch)
         .setIsr(isr)
         .setPartitionEpoch(1)
diff --git a/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala 
b/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala
index 39a2edb504..1cf66a9b4c 100644
--- a/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala
+++ b/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala
@@ -17,7 +17,7 @@
 
 package kafka.cluster
 
-import java.util.Properties
+import java.util.{Optional, Properties}
 import java.util.concurrent._
 import java.util.concurrent.atomic.AtomicBoolean
 
@@ -29,7 +29,9 @@ import kafka.server.epoch.LeaderEpochFileCache
 import kafka.server.metadata.MockConfigRepository
 import kafka.utils._
 import 
org.apache.kafka.common.message.LeaderAndIsrRequestData.LeaderAndIsrPartitionState
+import org.apache.kafka.common.protocol.ApiKeys
 import org.apache.kafka.common.record.{MemoryRecords, SimpleRecord}
+import org.apache.kafka.common.requests.FetchRequest
 import org.apache.kafka.common.utils.Utils
 import org.apache.kafka.common.{TopicPartition, Uuid}
 import org.apache.kafka.server.common.MetadataVersion
@@ -61,7 +63,6 @@ class PartitionLockTest extends Logging {
   val executorService = Executors.newFixedThreadPool(numReplicaFetchers + 
numProducers + 1)
   val appendSemaphore = new Semaphore(0)
   val shrinkIsrSemaphore = new Semaphore(0)
-  val followerQueues = (0 until numReplicaFetchers).map(_ => new 
ArrayBlockingQueue[MemoryRecords](2))
 
   var logManager: LogManager = _
   var partition: Partition = _
@@ -181,14 +182,16 @@ class PartitionLockTest extends Logging {
    * Then release the permit for the final append and verify that all appends 
and follower updates complete.
    */
   private def concurrentProduceFetchWithReadLockOnly(): Unit = {
+    val leaderEpoch = partition.getLeaderEpoch
+
     val appendFutures = scheduleAppends()
-    val stateUpdateFutures = scheduleUpdateFollowers(numProducers * 
numRecordsPerProducer - 1)
+    val stateUpdateFutures = scheduleFollowerFetches(leaderEpoch, numRecords = 
numProducers * numRecordsPerProducer - 1)
 
     appendSemaphore.release(numProducers * numRecordsPerProducer - 1)
     stateUpdateFutures.foreach(_.get(15, TimeUnit.SECONDS))
 
     appendSemaphore.release(1)
-    scheduleUpdateFollowers(1).foreach(_.get(15, TimeUnit.SECONDS)) // just to 
make sure follower state update still works
+    scheduleFollowerFetches(leaderEpoch, numRecords = 1).foreach(_.get(15, 
TimeUnit.SECONDS)) // just to make sure follower state update still works
     appendFutures.foreach(_.get(15, TimeUnit.SECONDS))
   }
 
@@ -199,9 +202,10 @@ class PartitionLockTest extends Logging {
    * permits for all appends to complete before verifying state updates.
    */
   private def concurrentProduceFetchWithWriteLock(): Unit = {
+    val leaderEpoch = partition.getLeaderEpoch
 
     val appendFutures = scheduleAppends()
-    val stateUpdateFutures = scheduleUpdateFollowers(numProducers * 
numRecordsPerProducer)
+    val stateUpdateFutures = scheduleFollowerFetches(leaderEpoch, numRecords = 
numProducers * numRecordsPerProducer)
 
     assertFalse(stateUpdateFutures.exists(_.isDone))
     appendSemaphore.release(numProducers * numRecordsPerProducer)
@@ -216,7 +220,7 @@ class PartitionLockTest extends Logging {
     (0 until numProducers).map { _ =>
       executorService.submit((() => {
         try {
-          append(partition, numRecordsPerProducer, followerQueues)
+          append(partition, numRecordsPerProducer)
         } catch {
           case e: Throwable =>
             error("Exception during append", e)
@@ -226,11 +230,11 @@ class PartitionLockTest extends Logging {
     }
   }
 
-  private def scheduleUpdateFollowers(numRecords: Int): Seq[Future[_]] = {
+  private def scheduleFollowerFetches(leaderEpoch: Int, numRecords: Int): 
Seq[Future[_]] = {
     (1 to numReplicaFetchers).map { index =>
       executorService.submit((() => {
         try {
-          updateFollowerFetchState(partition, index, numRecords, 
followerQueues(index - 1))
+          fetchFollower(partition, index, leaderEpoch, numRecords)
         } catch {
           case e: Throwable =>
             error("Exception during updateFollowerFetchState", e)
@@ -352,30 +356,68 @@ class PartitionLockTest extends Logging {
     logProps
   }
 
-  private def append(partition: Partition, numRecords: Int, followerQueues: 
Seq[ArrayBlockingQueue[MemoryRecords]]): Unit = {
+  private def append(
+    partition: Partition,
+    numRecords: Int
+  ): Unit = {
     val requestLocal = RequestLocal.withThreadConfinedCaching
     (0 until numRecords).foreach { _ =>
       val batch = TestUtils.records(records = List(new 
SimpleRecord("k1".getBytes, "v1".getBytes),
         new SimpleRecord("k2".getBytes, "v2".getBytes)))
       partition.appendRecordsToLeader(batch, origin = AppendOrigin.Client, 
requiredAcks = 0, requestLocal)
-      followerQueues.foreach(_.put(batch))
     }
   }
 
-  private def updateFollowerFetchState(partition: Partition, followerId: Int, 
numRecords: Int, followerQueue: ArrayBlockingQueue[MemoryRecords]): Unit = {
-    (1 to numRecords).foreach { i =>
-      val batch = followerQueue.poll(15, TimeUnit.SECONDS)
-      if (batch == null)
-        throw new RuntimeException(s"Timed out waiting for next batch $i")
-      val batches = batch.batches.iterator.asScala.toList
-      assertEquals(1, batches.size)
-      val recordBatch = batches.head
-      partition.updateFollowerFetchState(
-        followerId,
-        followerFetchOffsetMetadata = LogOffsetMetadata(recordBatch.lastOffset 
+ 1),
-        followerStartOffset = 0L,
-        followerFetchTimeMs = mockTime.milliseconds(),
-        leaderEndOffset = partition.localLogOrException.logEndOffset)
+  private def fetchFollower(
+    partition: Partition,
+    followerId: Int,
+    leaderEpoch: Int,
+    numRecords: Int
+  ): Unit = {
+    val logStartOffset = 0L
+    var fetchOffset = 0L
+    var lastFetchedEpoch = Optional.empty[Integer]
+    val maxBytes = 1
+
+    while (fetchOffset < numRecords) {
+      val fetchParams = FetchParams(
+        requestVersion = ApiKeys.FETCH.latestVersion,
+        replicaId = followerId,
+        maxWaitMs = 0,
+        minBytes = 1,
+        maxBytes = maxBytes,
+        isolation = FetchLogEnd,
+        clientMetadata = None
+      )
+
+      val fetchPartitionData = new FetchRequest.PartitionData(
+        Uuid.ZERO_UUID,
+        fetchOffset,
+        logStartOffset,
+        maxBytes,
+        Optional.of(Int.box(leaderEpoch)),
+        lastFetchedEpoch
+      )
+
+      val logReadInfo = partition.fetchRecords(
+        fetchParams,
+        fetchPartitionData,
+        mockTime.milliseconds(),
+        maxBytes,
+        minOneMessage = true,
+        updateFetchState = true
+      )
+
+      assertTrue(logReadInfo.divergingEpoch.isEmpty)
+
+      val batches = logReadInfo.fetchedData.records.batches.asScala
+      if (batches.nonEmpty) {
+        assertEquals(1, batches.size)
+
+        val batch = batches.head
+        lastFetchedEpoch = Optional.of(Int.box(batch.partitionLeaderEpoch))
+        fetchOffset = batch.lastOffset + 1
+      }
     }
   }
 
diff --git a/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala 
b/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
index 7d45617fec..04d2b15c60 100644
--- a/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
+++ b/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
@@ -16,6 +16,7 @@
  */
 package kafka.cluster
 
+import java.net.InetAddress
 import com.yammer.metrics.core.Metric
 import kafka.common.UnexpectedAppendOffsetException
 import kafka.log.{Defaults => _, _}
@@ -24,13 +25,13 @@ import kafka.server.checkpoints.OffsetCheckpoints
 import kafka.server.epoch.EpochEntry
 import kafka.utils._
 import kafka.zk.KafkaZkClient
-import org.apache.kafka.common.errors.{ApiException, 
InconsistentTopicIdException, NotLeaderOrFollowerException, 
OffsetNotAvailableException, OffsetOutOfRangeException}
+import org.apache.kafka.common.errors.{ApiException, 
FencedLeaderEpochException, InconsistentTopicIdException, 
NotLeaderOrFollowerException, OffsetNotAvailableException, 
OffsetOutOfRangeException, UnknownLeaderEpochException}
 import org.apache.kafka.common.message.FetchResponseData
 import 
org.apache.kafka.common.message.LeaderAndIsrRequestData.LeaderAndIsrPartitionState
-import org.apache.kafka.common.protocol.Errors
+import org.apache.kafka.common.protocol.{ApiKeys, Errors}
 import org.apache.kafka.common.record.FileRecords.TimestampAndOffset
 import org.apache.kafka.common.record._
-import org.apache.kafka.common.requests.ListOffsetsRequest
+import org.apache.kafka.common.requests.{FetchRequest, ListOffsetsRequest}
 import org.apache.kafka.common.utils.SystemTime
 import org.apache.kafka.common.{IsolationLevel, TopicPartition, Uuid}
 import org.apache.kafka.metadata.LeaderRecoveryState
@@ -45,13 +46,56 @@ import java.nio.ByteBuffer
 import java.util.Optional
 import java.util.concurrent.{CountDownLatch, Semaphore}
 import kafka.server.epoch.LeaderEpochFileCache
+import org.apache.kafka.common.network.ListenerName
+import org.apache.kafka.common.replica.ClientMetadata
+import org.apache.kafka.common.replica.ClientMetadata.DefaultClientMetadata
+import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol}
 import org.apache.kafka.server.common.MetadataVersion
 import org.apache.kafka.server.common.MetadataVersion.IBP_2_6_IV0
 import org.apache.kafka.server.metrics.KafkaYammerMetrics
 
+import scala.compat.java8.OptionConverters._
 import scala.jdk.CollectionConverters._
 
+object PartitionTest {
+  def followerFetchParams(
+    replicaId: Int,
+    maxWaitMs: Long = 0L,
+    minBytes: Int = 1,
+    maxBytes: Int = Int.MaxValue
+  ): FetchParams = {
+    FetchParams(
+      requestVersion = ApiKeys.FETCH.latestVersion,
+      replicaId = replicaId,
+      maxWaitMs = maxWaitMs,
+      minBytes = minBytes,
+      maxBytes = maxBytes,
+      isolation = FetchLogEnd,
+      clientMetadata = None
+    )
+  }
+
+  def consumerFetchParams(
+    maxWaitMs: Long = 0L,
+    minBytes: Int = 1,
+    maxBytes: Int = Int.MaxValue,
+    clientMetadata: Option[ClientMetadata] = None,
+    isolation: FetchIsolation = FetchHighWatermark
+  ): FetchParams = {
+    FetchParams(
+      requestVersion = ApiKeys.FETCH.latestVersion,
+      replicaId = FetchRequest.CONSUMER_REPLICA_ID,
+      maxWaitMs = maxWaitMs,
+      minBytes = minBytes,
+      maxBytes = maxBytes,
+      isolation = isolation,
+      clientMetadata = clientMetadata
+    )
+  }
+}
+
 class PartitionTest extends AbstractPartitionTest {
+  import PartitionTest._
 
   @Test
   def testLastFetchedOffsetValidation(): Unit = {
@@ -74,6 +118,7 @@ class PartitionTest extends AbstractPartitionTest {
     assertEquals(17L, log.logEndOffset)
 
     val leaderEpoch = 10
+    val logStartOffset = 0L
     val partition = setupPartitionWithMocks(leaderEpoch = leaderEpoch, 
isLeader = true)
 
     def epochEndOffset(epoch: Int, endOffset: Long): 
FetchResponseData.EpochEndOffset = {
@@ -83,14 +128,13 @@ class PartitionTest extends AbstractPartitionTest {
     }
 
     def read(lastFetchedEpoch: Int, fetchOffset: Long): LogReadInfo = {
-      partition.readRecords(
-        Optional.of(lastFetchedEpoch),
+      fetchFollower(
+        partition,
+        remoteReplicaId,
         fetchOffset,
-        currentLeaderEpoch = Optional.of(leaderEpoch),
-        maxBytes = Int.MaxValue,
-        fetchIsolation = FetchLogEnd,
-        fetchOnlyFromLeader = true,
-        minOneMessage = true
+        logStartOffset,
+        leaderEpoch = Some(leaderEpoch),
+        lastFetchedEpoch = Some(lastFetchedEpoch)
       )
     }
 
@@ -192,6 +236,84 @@ class PartitionTest extends AbstractPartitionTest {
     assertEquals(None, partition.futureLog)
   }
 
+  @Test
+  def testFetchFromUnrecognizedFollower(): Unit = {
+    val controllerEpoch = 3
+    val leader = brokerId
+    val validReplica = brokerId + 1
+    val addingReplica1 = brokerId + 2
+    val addingReplica2 = brokerId + 3
+    val replicas = List(leader, validReplica)
+    val isr = List[Integer](leader, validReplica).asJava
+    val leaderEpoch = 8
+    val partitionEpoch = 1
+
+    assertTrue(partition.makeLeader(new LeaderAndIsrPartitionState()
+      .setControllerEpoch(controllerEpoch)
+      .setLeader(leader)
+      .setLeaderEpoch(leaderEpoch)
+      .setIsr(isr)
+      .setPartitionEpoch(partitionEpoch)
+      .setReplicas(replicas.map(Int.box).asJava)
+      .setIsNew(true),
+      offsetCheckpoints, None
+    ))
+
+    assertThrows(classOf[UnknownLeaderEpochException], () => {
+      fetchFollower(
+        partition,
+        replicaId = addingReplica1,
+        fetchOffset = 0L,
+        leaderEpoch = Some(leaderEpoch)
+      )
+    })
+    assertEquals(None, 
partition.getReplica(addingReplica1).map(_.stateSnapshot.logEndOffset))
+
+    assertThrows(classOf[NotLeaderOrFollowerException], () => {
+      fetchFollower(
+        partition,
+        replicaId = addingReplica2,
+        fetchOffset = 0L,
+        leaderEpoch = None
+      )
+    })
+    assertEquals(None, 
partition.getReplica(addingReplica2).map(_.stateSnapshot.logEndOffset))
+
+    // The replicas are added as part of a reassignment
+    val newReplicas = List(leader, validReplica, addingReplica1, 
addingReplica2)
+    val newPartitionEpoch = partitionEpoch + 1
+    val addingReplicas = List(addingReplica1, addingReplica2)
+
+    assertFalse(partition.makeLeader(new LeaderAndIsrPartitionState()
+      .setControllerEpoch(controllerEpoch)
+      .setLeader(leader)
+      .setLeaderEpoch(leaderEpoch)
+      .setIsr(isr)
+      .setPartitionEpoch(newPartitionEpoch)
+      .setReplicas(newReplicas.map(Int.box).asJava)
+      .setAddingReplicas(addingReplicas.map(Int.box).asJava)
+      .setIsNew(true),
+      offsetCheckpoints, None
+    ))
+
+    // Now the fetches are allowed
+    assertEquals(0L, fetchFollower(
+      partition,
+      replicaId = addingReplica1,
+      fetchOffset = 0L,
+      leaderEpoch = Some(leaderEpoch)
+    ).logEndOffset)
+    assertEquals(Some(0L), 
partition.getReplica(addingReplica1).map(_.stateSnapshot.logEndOffset))
+
+    assertEquals(0L, fetchFollower(
+      partition,
+      replicaId = addingReplica2,
+      fetchOffset = 0L,
+      leaderEpoch = None
+    ).logEndOffset)
+    assertEquals(Some(0L), 
partition.getReplica(addingReplica2).map(_.stateSnapshot.logEndOffset))
+  }
+
   // Verify that partition.makeFollower() and 
partition.appendRecordsToFollowerOrFutureReplica() can run concurrently
   @Test
   def testMakeFollowerWithWithFollowerAppendRecords(): Unit = {
@@ -405,69 +527,59 @@ class PartitionTest extends AbstractPartitionTest {
   }
 
   @Test
-  def testReadRecordEpochValidationForLeader(): Unit = {
+  def testLeaderEpochValidationOnLeader(): Unit = {
     val leaderEpoch = 5
     val partition = setupPartitionWithMocks(leaderEpoch, isLeader = true)
 
-    def assertReadRecordsError(error: Errors,
-                               currentLeaderEpochOpt: Optional[Integer]): Unit 
= {
-      try {
-        partition.readRecords(
-          lastFetchedEpoch = Optional.empty(),
-          fetchOffset = 0L,
-          currentLeaderEpoch = currentLeaderEpochOpt,
-          maxBytes = 1024,
-          fetchIsolation = FetchLogEnd,
-          fetchOnlyFromLeader = true,
-          minOneMessage = false)
-        if (error != Errors.NONE)
-          fail(s"Expected readRecords to fail with error $error")
-      } catch {
-        case e: Exception =>
-          assertEquals(error, Errors.forException(e))
-      }
+    def sendFetch(leaderEpoch: Option[Int]): LogReadInfo = {
+      fetchFollower(
+        partition,
+        remoteReplicaId,
+        fetchOffset = 0L,
+        leaderEpoch = leaderEpoch
+      )
     }
 
-    assertReadRecordsError(Errors.NONE, Optional.empty())
-    assertReadRecordsError(Errors.NONE, Optional.of(leaderEpoch))
-    assertReadRecordsError(Errors.FENCED_LEADER_EPOCH, Optional.of(leaderEpoch 
- 1))
-    assertReadRecordsError(Errors.UNKNOWN_LEADER_EPOCH, 
Optional.of(leaderEpoch + 1))
+    assertEquals(0L, sendFetch(leaderEpoch = None).logEndOffset)
+    assertEquals(0L, sendFetch(leaderEpoch = Some(leaderEpoch)).logEndOffset)
+    assertThrows(classOf[FencedLeaderEpochException], () => 
sendFetch(Some(leaderEpoch - 1)))
+    assertThrows(classOf[UnknownLeaderEpochException], () => 
sendFetch(Some(leaderEpoch + 1)))
   }
 
   @Test
-  def testReadRecordEpochValidationForFollower(): Unit = {
+  def testLeaderEpochValidationOnFollower(): Unit = {
     val leaderEpoch = 5
     val partition = setupPartitionWithMocks(leaderEpoch, isLeader = false)
 
-    def assertReadRecordsError(error: Errors,
-                               currentLeaderEpochOpt: Optional[Integer],
-                               fetchOnlyLeader: Boolean): Unit = {
-      try {
-        partition.readRecords(
-          lastFetchedEpoch = Optional.empty(),
-          fetchOffset = 0L,
-          currentLeaderEpoch = currentLeaderEpochOpt,
-          maxBytes = 1024,
-          fetchIsolation = FetchLogEnd,
-          fetchOnlyFromLeader = fetchOnlyLeader,
-          minOneMessage = false)
-        if (error != Errors.NONE)
-          fail(s"Expected readRecords to fail with error $error")
-      } catch {
-        case e: Exception =>
-          assertEquals(error, Errors.forException(e))
-      }
+    def sendFetch(
+      leaderEpoch: Option[Int],
+      clientMetadata: Option[ClientMetadata]
+    ): LogReadInfo = {
+      fetchConsumer(
+        partition,
+        fetchOffset = 0L,
+        leaderEpoch = leaderEpoch,
+        clientMetadata = clientMetadata
+      )
     }
 
-    assertReadRecordsError(Errors.NONE, Optional.empty(), fetchOnlyLeader = 
false)
-    assertReadRecordsError(Errors.NONE, Optional.of(leaderEpoch), 
fetchOnlyLeader = false)
-    assertReadRecordsError(Errors.FENCED_LEADER_EPOCH, Optional.of(leaderEpoch 
- 1), fetchOnlyLeader = false)
-    assertReadRecordsError(Errors.UNKNOWN_LEADER_EPOCH, 
Optional.of(leaderEpoch + 1), fetchOnlyLeader = false)
-
-    assertReadRecordsError(Errors.NOT_LEADER_OR_FOLLOWER, Optional.empty(), 
fetchOnlyLeader = true)
-    assertReadRecordsError(Errors.NOT_LEADER_OR_FOLLOWER, 
Optional.of(leaderEpoch), fetchOnlyLeader = true)
-    assertReadRecordsError(Errors.FENCED_LEADER_EPOCH, Optional.of(leaderEpoch 
- 1), fetchOnlyLeader = true)
-    assertReadRecordsError(Errors.UNKNOWN_LEADER_EPOCH, 
Optional.of(leaderEpoch + 1), fetchOnlyLeader = true)
+    // Follower fetching is only allowed when the client provides metadata
+    assertThrows(classOf[NotLeaderOrFollowerException], () => sendFetch(None, 
None))
+    assertThrows(classOf[NotLeaderOrFollowerException], () => 
sendFetch(Some(leaderEpoch), None))
+    assertThrows(classOf[FencedLeaderEpochException], () => 
sendFetch(Some(leaderEpoch - 1), None))
+    assertThrows(classOf[UnknownLeaderEpochException], () => 
sendFetch(Some(leaderEpoch + 1), None))
+
+    val clientMetadata = new DefaultClientMetadata(
+      "rack",
+      "clientId",
+      InetAddress.getLoopbackAddress,
+      KafkaPrincipal.ANONYMOUS,
+      ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT).value
+    )
+    assertEquals(0L, sendFetch(leaderEpoch = None, 
Some(clientMetadata)).logEndOffset)
+    assertEquals(0L, sendFetch(leaderEpoch = Some(leaderEpoch), 
Some(clientMetadata)).logEndOffset)
+    assertThrows(classOf[FencedLeaderEpochException], () => 
sendFetch(Some(leaderEpoch - 1), Some(clientMetadata)))
+    assertThrows(classOf[UnknownLeaderEpochException], () => 
sendFetch(Some(leaderEpoch + 1), Some(clientMetadata)))
   }
 
   @Test
@@ -588,16 +700,6 @@ class PartitionTest extends AbstractPartitionTest {
     assertEquals(partition.localLogOrException.logStartOffset, 
partition.localLogOrException.highWatermark,
       "Expected leader's HW not move")
 
-    // let the follower in ISR move leader's HW to move further but below LEO
-    def updateFollowerFetchState(followerId: Int, fetchOffsetMetadata: 
LogOffsetMetadata): Unit = {
-      partition.updateFollowerFetchState(
-        followerId,
-        followerFetchOffsetMetadata = fetchOffsetMetadata,
-        followerStartOffset = 0L,
-        followerFetchTimeMs = time.milliseconds(),
-        leaderEndOffset = partition.localLogOrException.logEndOffset)
-    }
-
     def fetchOffsetsForTimestamp(timestamp: Long, isolation: 
Option[IsolationLevel]): Either[ApiException, Option[TimestampAndOffset]] = {
       try {
         Right(partition.fetchOffsetForTimestamp(
@@ -611,11 +713,12 @@ class PartitionTest extends AbstractPartitionTest {
       }
     }
 
-    updateFollowerFetchState(follower1, LogOffsetMetadata(0))
-    updateFollowerFetchState(follower1, LogOffsetMetadata(2))
+    // let the follower in ISR move leader's HW to move further but below LEO
+    fetchFollower(partition, replicaId = follower1, fetchOffset = 0L)
+    fetchFollower(partition, replicaId = follower1, fetchOffset = 2L)
 
-    updateFollowerFetchState(follower2, LogOffsetMetadata(0))
-    updateFollowerFetchState(follower2, LogOffsetMetadata(2))
+    fetchFollower(partition, replicaId = follower2, fetchOffset = 0L)
+    fetchFollower(partition, replicaId = follower2, fetchOffset = 2L)
 
     // Simulate successful ISR update
     alterPartitionManager.completeIsrUpdate(2)
@@ -704,8 +807,8 @@ class PartitionTest extends AbstractPartitionTest {
     }
 
     // Next fetch from replicas, HW is moved up to 5 (ahead of the LEO)
-    updateFollowerFetchState(follower1, LogOffsetMetadata(5))
-    updateFollowerFetchState(follower2, LogOffsetMetadata(5))
+    fetchFollower(partition, replicaId = follower1, fetchOffset = 5L)
+    fetchFollower(partition, replicaId = follower2, fetchOffset = 5L)
 
     // Simulate successful ISR update
     alterPartitionManager.completeIsrUpdate(6)
@@ -919,17 +1022,8 @@ class PartitionTest extends AbstractPartitionTest {
     assertEquals(partition.localLogOrException.logStartOffset, 
partition.log.get.highWatermark, "Expected leader's HW not move")
 
     // let the follower in ISR move leader's HW to move further but below LEO
-    def updateFollowerFetchState(followerId: Int, fetchOffsetMetadata: 
LogOffsetMetadata): Unit = {
-      partition.updateFollowerFetchState(
-        followerId,
-        followerFetchOffsetMetadata = fetchOffsetMetadata,
-        followerStartOffset = 0L,
-        followerFetchTimeMs = time.milliseconds(),
-        leaderEndOffset = partition.localLogOrException.logEndOffset)
-    }
-
-    updateFollowerFetchState(follower2, LogOffsetMetadata(0))
-    updateFollowerFetchState(follower2, 
LogOffsetMetadata(lastOffsetOfFirstBatch))
+    fetchFollower(partition, replicaId = follower2, fetchOffset = 0)
+    fetchFollower(partition, replicaId = follower2, fetchOffset = 
lastOffsetOfFirstBatch)
     assertEquals(lastOffsetOfFirstBatch, partition.log.get.highWatermark, 
"Expected leader's HW")
 
     // current leader becomes follower and then leader again (without any new 
records appended)
@@ -959,13 +1053,13 @@ class PartitionTest extends AbstractPartitionTest {
     partition.appendRecordsToLeader(batch3, origin = AppendOrigin.Client, 
requiredAcks = 0, requestLocal)
 
     // fetch from follower not in ISR from log start offset should not add 
this follower to ISR
-    updateFollowerFetchState(follower1, LogOffsetMetadata(0))
-    updateFollowerFetchState(follower1, 
LogOffsetMetadata(lastOffsetOfFirstBatch))
+    fetchFollower(partition, replicaId = follower1, fetchOffset = 0)
+    fetchFollower(partition, replicaId = follower1, fetchOffset = 
lastOffsetOfFirstBatch)
     assertEquals(Set[Integer](leader, follower2), 
partition.partitionState.isr, "ISR")
 
     // fetch from the follower not in ISR from start offset of the current 
leader epoch should
     // add this follower to ISR
-    updateFollowerFetchState(follower1, 
LogOffsetMetadata(currentLeaderEpochStartOffset))
+    fetchFollower(partition, replicaId = follower1, fetchOffset = 
currentLeaderEpochStartOffset)
 
     // Expansion does not affect the ISR
     assertEquals(Set[Integer](leader, follower2), 
partition.partitionState.isr, "ISR")
@@ -1057,12 +1151,7 @@ class PartitionTest extends AbstractPartitionTest {
 
     time.sleep(500)
 
-    partition.updateFollowerFetchState(remoteBrokerId,
-      followerFetchOffsetMetadata = LogOffsetMetadata(3),
-      followerStartOffset = 0L,
-      followerFetchTimeMs = time.milliseconds(),
-      leaderEndOffset = 6L)
-
+    fetchFollower(partition, replicaId = remoteBrokerId, fetchOffset = 3L)
     assertReplicaState(partition, remoteBrokerId,
       lastCaughtUpTimeMs = initializeTimeMs,
       logStartOffset = 0L,
@@ -1071,12 +1160,7 @@ class PartitionTest extends AbstractPartitionTest {
 
     time.sleep(500)
 
-    partition.updateFollowerFetchState(remoteBrokerId,
-      followerFetchOffsetMetadata = LogOffsetMetadata(6L),
-      followerStartOffset = 0L,
-      followerFetchTimeMs = time.milliseconds(),
-      leaderEndOffset = 6L)
-
+    fetchFollower(partition, replicaId = remoteBrokerId, fetchOffset = 6L)
     assertReplicaState(partition, remoteBrokerId,
       lastCaughtUpTimeMs = time.milliseconds(),
       logStartOffset = 0L,
@@ -1114,11 +1198,7 @@ class PartitionTest extends AbstractPartitionTest {
       logEndOffset = UnifiedLog.UnknownOffset
     )
 
-    partition.updateFollowerFetchState(remoteBrokerId,
-      followerFetchOffsetMetadata = LogOffsetMetadata(10),
-      followerStartOffset = 0L,
-      followerFetchTimeMs = time.milliseconds(),
-      leaderEndOffset = 10L)
+    fetchFollower(partition, replicaId = remoteBrokerId, fetchOffset = 10L)
 
     // Check that the isr didn't change and alter update is scheduled
     assertEquals(Set(brokerId), partition.inSyncReplicaIds)
@@ -1169,12 +1249,7 @@ class PartitionTest extends AbstractPartitionTest {
       logEndOffset = UnifiedLog.UnknownOffset
     )
 
-    partition.updateFollowerFetchState(remoteBrokerId,
-      followerFetchOffsetMetadata = LogOffsetMetadata(3),
-      followerStartOffset = 0L,
-      followerFetchTimeMs = time.milliseconds(),
-      leaderEndOffset = 6L)
-
+    fetchFollower(partition, replicaId = remoteBrokerId, fetchOffset = 3L)
     assertEquals(Set(brokerId), partition.partitionState.isr)
     assertReplicaState(partition, remoteBrokerId,
       lastCaughtUpTimeMs = 0L,
@@ -1182,12 +1257,7 @@ class PartitionTest extends AbstractPartitionTest {
       logEndOffset = 3L
     )
 
-    partition.updateFollowerFetchState(remoteBrokerId,
-      followerFetchOffsetMetadata = LogOffsetMetadata(10),
-      followerStartOffset = 0L,
-      followerFetchTimeMs = time.milliseconds(),
-      leaderEndOffset = 6L)
-
+    fetchFollower(partition, replicaId = remoteBrokerId, fetchOffset = 10L)
     assertEquals(alterPartitionManager.isrUpdates.size, 1)
     val isrItem = alterPartitionManager.isrUpdates.head
     assertEquals(isrItem.leaderAndIsr.isr, List(brokerId, remoteBrokerId))
@@ -1238,11 +1308,7 @@ class PartitionTest extends AbstractPartitionTest {
       logEndOffset = UnifiedLog.UnknownOffset
     )
 
-    partition.updateFollowerFetchState(remoteBrokerId,
-      followerFetchOffsetMetadata = LogOffsetMetadata(10),
-      followerStartOffset = 0L,
-      followerFetchTimeMs = time.milliseconds(),
-      leaderEndOffset = 10L)
+    fetchFollower(partition, replicaId = remoteBrokerId, fetchOffset = 10L)
 
     // Follower state is updated, but the ISR has not expanded
     assertEquals(Set(brokerId), partition.inSyncReplicaIds)
@@ -1465,11 +1531,7 @@ class PartitionTest extends AbstractPartitionTest {
     // There is a short delay before the first fetch. The follower is not yet 
caught up to the log end.
     time.sleep(5000)
     val firstFetchTimeMs = time.milliseconds()
-    partition.updateFollowerFetchState(remoteBrokerId,
-      followerFetchOffsetMetadata = LogOffsetMetadata(5),
-      followerStartOffset = 0L,
-      followerFetchTimeMs = firstFetchTimeMs,
-      leaderEndOffset = 10L)
+    fetchFollower(partition, replicaId = remoteBrokerId, fetchOffset = 5L, 
fetchTimeMs = firstFetchTimeMs)
     assertReplicaState(partition, remoteBrokerId,
       lastCaughtUpTimeMs = initializeTimeMs,
       logStartOffset = 0L,
@@ -1481,11 +1543,7 @@ class PartitionTest extends AbstractPartitionTest {
     // The total elapsed time from initialization is larger than the max 
allowed replica lag.
     time.sleep(5001)
     seedLogData(log, numRecords = 5, leaderEpoch = leaderEpoch)
-    partition.updateFollowerFetchState(remoteBrokerId,
-      followerFetchOffsetMetadata = LogOffsetMetadata(10),
-      followerStartOffset = 0L,
-      followerFetchTimeMs = time.milliseconds(),
-      leaderEndOffset = 15L)
+    fetchFollower(partition, replicaId = remoteBrokerId, fetchOffset = 10L, 
fetchTimeMs = time.milliseconds())
     assertReplicaState(partition, remoteBrokerId,
       lastCaughtUpTimeMs = firstFetchTimeMs,
       logStartOffset = 0L,
@@ -1530,11 +1588,7 @@ class PartitionTest extends AbstractPartitionTest {
     )
 
     // The follower catches up to the log end immediately.
-    partition.updateFollowerFetchState(remoteBrokerId,
-      followerFetchOffsetMetadata = LogOffsetMetadata(10),
-      followerStartOffset = 0L,
-      followerFetchTimeMs = time.milliseconds(),
-      leaderEndOffset = 10L)
+    fetchFollower(partition, replicaId = remoteBrokerId, fetchOffset = 10L)
     assertReplicaState(partition, remoteBrokerId,
       lastCaughtUpTimeMs = time.milliseconds(),
       logStartOffset = 0L,
@@ -1658,11 +1712,7 @@ class PartitionTest extends AbstractPartitionTest {
 
     // This will attempt to expand the ISR
     val firstFetchTimeMs = time.milliseconds()
-    partition.updateFollowerFetchState(remoteBrokerId,
-      followerFetchOffsetMetadata = LogOffsetMetadata(10),
-      followerStartOffset = 0L,
-      followerFetchTimeMs = time.milliseconds(),
-      leaderEndOffset = 10L)
+    fetchFollower(partition, replicaId = remoteBrokerId, fetchOffset = 10L, 
fetchTimeMs = firstFetchTimeMs)
 
     // Follower state is updated, but the ISR has not expanded
     assertEquals(Set(brokerId), partition.inSyncReplicaIds)
@@ -1706,13 +1756,7 @@ class PartitionTest extends AbstractPartitionTest {
     assertEquals(0L, partition.localLogOrException.highWatermark)
 
     // Expand ISR
-    partition.updateFollowerFetchState(
-      followerId = follower3,
-      followerFetchOffsetMetadata = LogOffsetMetadata(10),
-      followerStartOffset = 0L,
-      followerFetchTimeMs = time.milliseconds(),
-      leaderEndOffset = 10
-    )
+    fetchFollower(partition, replicaId = follower3, fetchOffset = 10L)
     assertEquals(Set(brokerId, follower1, follower2), 
partition.partitionState.isr)
     assertEquals(Set(brokerId, follower1, follower2, follower3), 
partition.partitionState.maximalIsr)
 
@@ -1776,13 +1820,7 @@ class PartitionTest extends AbstractPartitionTest {
     assertEquals(0L, partition.localLogOrException.highWatermark)
 
     // Expand ISR
-    partition.updateFollowerFetchState(
-      followerId = follower3,
-      followerFetchOffsetMetadata = LogOffsetMetadata(10),
-      followerStartOffset = 0L,
-      followerFetchTimeMs = time.milliseconds(),
-      leaderEndOffset = 10
-    )
+    fetchFollower(partition, replicaId = follower3, fetchOffset = 10L)
 
     // Try avoiding a race
     TestUtils.waitUntilTrue(() => !partition.partitionState.isInflight, 
"Expected ISR state to be committed", 100)
@@ -2170,14 +2208,7 @@ class PartitionTest extends AbstractPartitionTest {
     )
 
     // Follower fetches and updates its replica state.
-    partition.updateFollowerFetchState(
-      followerId = followerId,
-      followerFetchOffsetMetadata = LogOffsetMetadata(0L),
-      followerStartOffset = 0L,
-      followerFetchTimeMs = time.milliseconds(),
-      leaderEndOffset = partition.localLogOrException.logEndOffset
-    )
-
+    fetchFollower(partition, replicaId = followerId, fetchOffset = 0L)
     assertReplicaState(partition, followerId,
       lastCaughtUpTimeMs = time.milliseconds(),
       logStartOffset = 0L,
@@ -2473,4 +2504,76 @@ class PartitionTest extends AbstractPartitionTest {
         fail(s"Replica $replicaId not found.")
     }
   }
+
+  private def fetchConsumer(
+    partition: Partition,
+    fetchOffset: Long,
+    leaderEpoch: Option[Int],
+    clientMetadata: Option[ClientMetadata],
+    maxBytes: Int = Int.MaxValue,
+    lastFetchedEpoch: Option[Int] = None,
+    fetchTimeMs: Long = time.milliseconds(),
+    topicId: Uuid = Uuid.ZERO_UUID,
+    isolation: FetchIsolation = FetchHighWatermark
+  ): LogReadInfo = {
+    val fetchParams = consumerFetchParams(
+      maxBytes = maxBytes,
+      clientMetadata = clientMetadata,
+      isolation = isolation
+    )
+
+    val fetchPartitionData = new FetchRequest.PartitionData(
+      topicId,
+      fetchOffset,
+      FetchRequest.INVALID_LOG_START_OFFSET,
+      maxBytes,
+      leaderEpoch.map(Int.box).asJava,
+      lastFetchedEpoch.map(Int.box).asJava
+    )
+
+    partition.fetchRecords(
+      fetchParams,
+      fetchPartitionData,
+      fetchTimeMs,
+      maxBytes,
+      minOneMessage = true,
+      updateFetchState = false
+    )
+  }
+
+  private def fetchFollower(
+    partition: Partition,
+    replicaId: Int,
+    fetchOffset: Long,
+    logStartOffset: Long = 0L,
+    maxBytes: Int = Int.MaxValue,
+    leaderEpoch: Option[Int] = None,
+    lastFetchedEpoch: Option[Int] = None,
+    fetchTimeMs: Long = time.milliseconds(),
+    topicId: Uuid = Uuid.ZERO_UUID
+  ): LogReadInfo = {
+    val fetchParams = followerFetchParams(
+      replicaId,
+      maxBytes = maxBytes
+    )
+
+    val fetchPartitionData = new FetchRequest.PartitionData(
+      topicId,
+      fetchOffset,
+      logStartOffset,
+      maxBytes,
+      leaderEpoch.map(Int.box).asJava,
+      lastFetchedEpoch.map(Int.box).asJava
+    )
+
+    partition.fetchRecords(
+      fetchParams,
+      fetchPartitionData,
+      fetchTimeMs,
+      maxBytes,
+      minOneMessage = true,
+      updateFetchState = true
+    )
+  }
+
 }
diff --git 
a/core/src/test/scala/unit/kafka/server/FetchRequestDownConversionConfigTest.scala
 
b/core/src/test/scala/unit/kafka/server/FetchRequestDownConversionConfigTest.scala
index 6efa37b117..0cf7c1d8e2 100644
--- 
a/core/src/test/scala/unit/kafka/server/FetchRequestDownConversionConfigTest.scala
+++ 
b/core/src/test/scala/unit/kafka/server/FetchRequestDownConversionConfigTest.scala
@@ -18,21 +18,25 @@ package kafka.server
 
 import java.util
 import java.util.{Optional, Properties}
+
 import kafka.log.LogConfig
-import kafka.utils.TestUtils
+import kafka.utils.{TestInfoUtils, TestUtils}
 import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord}
+import org.apache.kafka.common.message.FetchResponseData
 import org.apache.kafka.common.{TopicPartition, Uuid}
 import org.apache.kafka.common.protocol.{ApiKeys, Errors}
 import org.apache.kafka.common.requests.{FetchRequest, FetchResponse}
 import org.apache.kafka.common.serialization.StringSerializer
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo}
+import org.junit.jupiter.params.ParameterizedTest
+import org.junit.jupiter.params.provider.ValueSource
 
 import scala.jdk.CollectionConverters._
 
 class FetchRequestDownConversionConfigTest extends BaseRequestTest {
   private var producer: KafkaProducer[String, String] = null
-  override def brokerCount: Int = 1
+  override def brokerCount: Int = 2
 
   @BeforeEach
   override def setUp(testInfo: TestInfo): Unit = {
@@ -64,8 +68,12 @@ class FetchRequestDownConversionConfigTest extends 
BaseRequestTest {
     topicConfig.setProperty(LogConfig.MinInSyncReplicasProp, 1.toString)
     configs.foreach { case (k, v) => topicConfig.setProperty(k, v) }
     topics.flatMap { topic =>
-      val partitionToLeader = createTopic(topic, numPartitions = 
numPartitions, replicationFactor = 1,
-        topicConfig = topicConfig)
+      val partitionToLeader = createTopic(
+        topic,
+        numPartitions = numPartitions,
+        replicationFactor = 2,
+        topicConfig = topicConfig
+      )
       partitionToLeader.map { case (partition, leader) => new 
TopicPartition(topic, partition) -> leader }
     }.toMap
   }
@@ -140,56 +148,101 @@ class FetchRequestDownConversionConfigTest extends 
BaseRequestTest {
    * Tests that "message.downconversion.enable" can be set at topic level, and 
its configuration is obeyed for client
    * fetch requests.
    */
-  @Test
-  def testV1FetchWithTopicLevelOverrides(): Unit = {
-    // create topics with default down-conversion configuration (i.e. 
conversion disabled)
-    val conversionDisabledTopicsMap = createTopics(numTopics = 5, 
numPartitions = 1, topicSuffixStart = 0)
-    val conversionDisabledTopicPartitions = 
conversionDisabledTopicsMap.keySet.toSeq
-
-    // create topics with down-conversion configuration enabled
-    val topicConfig = Map(LogConfig.MessageDownConversionEnableProp -> "true")
-    val conversionEnabledTopicsMap = createTopics(numTopics = 5, numPartitions 
= 1, topicConfig, topicSuffixStart = 5)
-    val conversionEnabledTopicPartitions = 
conversionEnabledTopicsMap.keySet.toSeq
-
-    val allTopics = conversionDisabledTopicPartitions ++ 
conversionEnabledTopicPartitions
-    val leaderId = conversionDisabledTopicsMap.head._2
-    val topicIds = servers.head.kafkaController.controllerContext.topicIds
-    val topicNames = topicIds.map(_.swap)
-
-    allTopics.foreach(tp => producer.send(new ProducerRecord(tp.topic(), 
"key", "value")).get())
-    val fetchRequest = FetchRequest.Builder.forConsumer(1, Int.MaxValue, 0, 
createPartitionMap(1024,
-      allTopics, topicIds.toMap)).build(1)
-    val fetchResponse = sendFetchRequest(leaderId, fetchRequest)
-
-    val fetchResponseData = fetchResponse.responseData(topicNames.asJava, 1)
-    conversionDisabledTopicPartitions.foreach(tp => 
assertEquals(Errors.UNSUPPORTED_VERSION, 
Errors.forCode(fetchResponseData.get(tp).errorCode)))
-    conversionEnabledTopicPartitions.foreach(tp => assertEquals(Errors.NONE, 
Errors.forCode(fetchResponseData.get(tp).errorCode)))
+  @ParameterizedTest(name = TestInfoUtils.TestWithParameterizedQuorumName)
+  @ValueSource(strings = Array("zk", "kraft"))
+  def testV1FetchFromConsumer(quorum: String): Unit = {
+    testV1Fetch(isFollowerFetch = false)
   }
 
   /**
    * Tests that "message.downconversion.enable" has no effect on fetch 
requests from replicas.
    */
-  @Test
-  def testV1FetchFromReplica(): Unit = {
-    // create topics with default down-conversion configuration (i.e. 
conversion disabled)
-    val conversionDisabledTopicsMap = createTopics(numTopics = 5, 
numPartitions = 1, topicSuffixStart = 0)
-    val conversionDisabledTopicPartitions = 
conversionDisabledTopicsMap.keySet.toSeq
+  @ParameterizedTest(name = TestInfoUtils.TestWithParameterizedQuorumName)
+  @ValueSource(strings = Array("zk", "kraft"))
+  def testV1FetchFromReplica(quorum: String): Unit = {
+    testV1Fetch(isFollowerFetch = true)
+  }
 
-    // create topics with down-conversion configuration enabled
-    val topicConfig = Map(LogConfig.MessageDownConversionEnableProp -> "true")
-    val conversionEnabledTopicsMap = createTopics(numTopics = 5, numPartitions 
= 1, topicConfig, topicSuffixStart = 5)
-    val conversionEnabledTopicPartitions = 
conversionEnabledTopicsMap.keySet.toSeq
+  def testV1Fetch(isFollowerFetch: Boolean): Unit = {
+    val topicWithDownConversionEnabled = "foo"
+    val topicWithDownConversionDisabled = "bar"
+    val replicaIds = brokers.map(_.config.brokerId)
+    val leaderId = replicaIds.head
+    val followerId = replicaIds.last
 
-    val allTopicPartitions = conversionDisabledTopicPartitions ++ 
conversionEnabledTopicPartitions
-    val topicIds = servers.head.kafkaController.controllerContext.topicIds
-    val topicNames = topicIds.map(_.swap)
-    val leaderId = conversionDisabledTopicsMap.head._2
+    val admin = createAdminClient()
+
+    val topicWithDownConversionDisabledId = TestUtils.createTopicWithAdminRaw(
+      admin,
+      topicWithDownConversionDisabled,
+      replicaAssignment = Map(0 -> replicaIds)
+    )
+
+    val topicConfig = new Properties
+    topicConfig.put(LogConfig.MessageDownConversionEnableProp, "true")
+    val topicWithDownConversionEnabledId = TestUtils.createTopicWithAdminRaw(
+      admin,
+      topicWithDownConversionEnabled,
+      replicaAssignment = Map(0 -> replicaIds),
+      topicConfig = topicConfig
+    )
+
+    val partitionWithDownConversionEnabled = new 
TopicPartition(topicWithDownConversionEnabled, 0)
+    val partitionWithDownConversionDisabled = new 
TopicPartition(topicWithDownConversionDisabled, 0)
+
+    val allTopicPartitions = Seq(
+      partitionWithDownConversionEnabled,
+      partitionWithDownConversionDisabled
+    )
+
+    allTopicPartitions.foreach { tp =>
+      producer.send(new ProducerRecord(tp.topic, "key", "value")).get()
+    }
+
+    val topicIdMap = Map(
+      topicWithDownConversionEnabled -> topicWithDownConversionEnabledId,
+      topicWithDownConversionDisabled -> topicWithDownConversionDisabledId
+    )
+
+    val fetchResponseData = sendFetch(
+      leaderId,
+      allTopicPartitions,
+      topicIdMap,
+      fetchVersion = 1,
+      replicaIdOpt = if (isFollowerFetch) Some(followerId) else None
+    )
+
+    def error(tp: TopicPartition): Errors = {
+      Errors.forCode(fetchResponseData.get(tp).errorCode)
+    }
+
+    assertEquals(Errors.NONE, error(partitionWithDownConversionEnabled))
+    if (isFollowerFetch) {
+      assertEquals(Errors.NONE, error(partitionWithDownConversionDisabled))
+    } else {
+      assertEquals(Errors.UNSUPPORTED_VERSION, 
error(partitionWithDownConversionDisabled))
+    }
+  }
+
+  private def sendFetch(
+    leaderId: Int,
+    partitions: Seq[TopicPartition],
+    topicIdMap: Map[String, Uuid],
+    fetchVersion: Short,
+    replicaIdOpt: Option[Int]
+  ): util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData] = {
+    val topicNameMap = topicIdMap.map(_.swap)
+    val partitionMap = createPartitionMap(1024, partitions, topicIdMap)
+
+    val fetchRequest = replicaIdOpt.map { replicaId =>
+      FetchRequest.Builder.forReplica(fetchVersion, replicaId, Int.MaxValue, 
0, partitionMap)
+        .build(fetchVersion)
+    }.getOrElse {
+      FetchRequest.Builder.forConsumer(fetchVersion, Int.MaxValue, 0, 
partitionMap)
+        .build(fetchVersion)
+    }
 
-    allTopicPartitions.foreach(tp => producer.send(new 
ProducerRecord(tp.topic, "key", "value")).get())
-    val fetchRequest = FetchRequest.Builder.forReplica(1, 1, Int.MaxValue, 0,
-      createPartitionMap(1024, allTopicPartitions, topicIds.toMap)).build()
     val fetchResponse = sendFetchRequest(leaderId, fetchRequest)
-    val fetchResponseData = fetchResponse.responseData(topicNames.asJava, 1)
-    allTopicPartitions.foreach(tp => assertEquals(Errors.NONE, 
Errors.forCode(fetchResponseData.get(tp).errorCode)))
+    fetchResponse.responseData(topicNameMap.asJava, fetchVersion)
   }
 }
diff --git 
a/core/src/test/scala/unit/kafka/server/ReplicaManagerQuotasTest.scala 
b/core/src/test/scala/unit/kafka/server/ReplicaManagerQuotasTest.scala
index 49ac23ec23..18e810bb88 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaManagerQuotasTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerQuotasTest.scala
@@ -18,8 +18,7 @@ package kafka.server
 
 import java.io.File
 import java.util.{Collections, Optional, Properties}
-
-import kafka.cluster.Partition
+import kafka.cluster.{Partition, PartitionTest}
 import kafka.log.{LogManager, LogOffsetSnapshot, UnifiedLog}
 import kafka.server.QuotaFactory.QuotaManagers
 import kafka.utils._
@@ -32,9 +31,9 @@ import org.apache.kafka.common.{TopicIdPartition, 
TopicPartition, Uuid}
 import org.apache.kafka.metadata.LeaderRecoveryState
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{AfterEach, Test}
-import org.mockito.{AdditionalMatchers, ArgumentMatchers}
 import org.mockito.ArgumentMatchers.{any, anyBoolean, anyInt, anyLong}
 import org.mockito.Mockito.{mock, when}
+import org.mockito.{AdditionalMatchers, ArgumentMatchers}
 
 import scala.jdk.CollectionConverters._
 
@@ -65,18 +64,10 @@ class ReplicaManagerQuotasTest {
       .thenReturn(false)
       .thenReturn(true)
 
-    val fetch = replicaManager.readFromLocalLog(
-      replicaId = followerReplicaId,
-      fetchOnlyFromLeader = true,
-      fetchIsolation = FetchHighWatermark,
-      fetchMaxBytes = Int.MaxValue,
-      hardMaxBytesLimit = false,
-      readPartitionInfo = fetchInfo,
-      quota = quota,
-      clientMetadata = None)
+    val fetchParams = PartitionTest.followerFetchParams(followerReplicaId)
+    val fetch = replicaManager.readFromLocalLog(fetchParams, fetchInfo, quota, 
readFromPurgatory = false)
     assertEquals(1, fetch.find(_._1 == 
topicIdPartition1).get._2.info.records.batches.asScala.size,
       "Given two partitions, with only one throttled, we should get the first")
-
     assertEquals(0, fetch.find(_._1 == 
topicIdPartition2).get._2.info.records.batches.asScala.size,
       "But we shouldn't get the second")
   }
@@ -91,15 +82,8 @@ class ReplicaManagerQuotasTest {
       .thenReturn(true)
       .thenReturn(true)
 
-    val fetch = replicaManager.readFromLocalLog(
-      replicaId = followerReplicaId,
-      fetchOnlyFromLeader = true,
-      fetchIsolation = FetchHighWatermark,
-      fetchMaxBytes = Int.MaxValue,
-      hardMaxBytesLimit = false,
-      readPartitionInfo = fetchInfo,
-      quota = quota,
-      clientMetadata = None)
+    val fetchParams = PartitionTest.followerFetchParams(followerReplicaId)
+    val fetch = replicaManager.readFromLocalLog(fetchParams, fetchInfo, quota, 
readFromPurgatory = false)
     assertEquals(0, fetch.find(_._1 == 
topicIdPartition1).get._2.info.records.batches.asScala.size,
       "Given two partitions, with both throttled, we should get no messages")
     assertEquals(0, fetch.find(_._1 == 
topicIdPartition2).get._2.info.records.batches.asScala.size,
@@ -116,15 +100,8 @@ class ReplicaManagerQuotasTest {
       .thenReturn(false)
       .thenReturn(false)
 
-    val fetch = replicaManager.readFromLocalLog(
-      replicaId = followerReplicaId,
-      fetchOnlyFromLeader = true,
-      fetchIsolation = FetchHighWatermark,
-      fetchMaxBytes = Int.MaxValue,
-      hardMaxBytesLimit = false,
-      readPartitionInfo = fetchInfo,
-      quota = quota,
-      clientMetadata = None)
+    val fetchParams = PartitionTest.followerFetchParams(followerReplicaId)
+    val fetch = replicaManager.readFromLocalLog(fetchParams, fetchInfo, quota, 
readFromPurgatory = false)
     assertEquals(1, fetch.find(_._1 == 
topicIdPartition1).get._2.info.records.batches.asScala.size,
       "Given two partitions, with both non-throttled, we should get both 
messages")
     assertEquals(1, fetch.find(_._1 == 
topicIdPartition2).get._2.info.records.batches.asScala.size,
@@ -141,15 +118,8 @@ class ReplicaManagerQuotasTest {
       .thenReturn(false)
       .thenReturn(true)
 
-    val fetch = replicaManager.readFromLocalLog(
-      replicaId = followerReplicaId,
-      fetchOnlyFromLeader = true,
-      fetchIsolation = FetchHighWatermark,
-      fetchMaxBytes = Int.MaxValue,
-      hardMaxBytesLimit = false,
-      readPartitionInfo = fetchInfo,
-      quota = quota,
-      clientMetadata = None)
+    val fetchParams = PartitionTest.followerFetchParams(followerReplicaId)
+    val fetch = replicaManager.readFromLocalLog(fetchParams, fetchInfo, quota, 
readFromPurgatory = false)
     assertEquals(1, fetch.find(_._1 == 
topicIdPartition1).get._2.info.records.batches.asScala.size,
       "Given two partitions, with only one throttled, we should get the first")
 
@@ -164,19 +134,10 @@ class ReplicaManagerQuotasTest {
     val quota = mockQuota()
     when(quota.isQuotaExceeded).thenReturn(true)
 
-    val fetch = replicaManager.readFromLocalLog(
-      replicaId = FetchRequest.CONSUMER_REPLICA_ID,
-      fetchOnlyFromLeader = true,
-      fetchIsolation = FetchHighWatermark,
-      fetchMaxBytes = Int.MaxValue,
-      hardMaxBytesLimit = false,
-      readPartitionInfo = fetchInfo,
-      quota = quota,
-      clientMetadata = None).toMap
-
+    val fetchParams = PartitionTest.consumerFetchParams()
+    val fetch = replicaManager.readFromLocalLog(fetchParams, fetchInfo, quota, 
readFromPurgatory = false).toMap
     assertEquals(1, fetch(topicIdPartition1).info.records.batches.asScala.size,
       "Replication throttled partitions should return data for consumer fetch")
-
     assertEquals(1, fetch(topicIdPartition2).info.records.batches.asScala.size,
       "Replication throttled partitions should return data for consumer fetch")
   }
@@ -315,6 +276,10 @@ class ReplicaManagerQuotasTest {
         MemoryRecords.EMPTY
       ))
 
+    when(log.maybeIncrementHighWatermark(
+      any[LogOffsetMetadata]
+    )).thenReturn(None)
+
     //Create log manager
     val logManager: LogManager = mock(classOf[LogManager])
 
@@ -367,4 +332,5 @@ class ReplicaManagerQuotasTest {
     when(quota.isThrottled(any[TopicPartition])).thenReturn(true)
     quota
   }
+
 }
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala 
b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
index 977da6c69c..80e6112518 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
@@ -973,13 +973,14 @@ class ReplicaManagerTest {
       val partition0Replicas = Seq[Integer](0, 1).asJava
       val partition1Replicas = Seq[Integer](0, 2).asJava
       val topicIds = Map(tp0.topic -> topicId, tp1.topic -> topicId).asJava
+      val leaderEpoch = 0
       val leaderAndIsrRequest = new 
LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, 
brokerEpoch,
         Seq(
           new LeaderAndIsrPartitionState()
             .setTopicName(tp0.topic)
             .setPartitionIndex(tp0.partition)
             .setControllerEpoch(0)
-            .setLeader(0)
+            .setLeader(leaderEpoch)
             .setLeaderEpoch(0)
             .setIsr(partition0Replicas)
             .setPartitionEpoch(0)
@@ -990,7 +991,7 @@ class ReplicaManagerTest {
             .setPartitionIndex(tp1.partition)
             .setControllerEpoch(0)
             .setLeader(0)
-            .setLeaderEpoch(0)
+            .setLeaderEpoch(leaderEpoch)
             .setIsr(partition1Replicas)
             .setPartitionEpoch(0)
             .setReplicas(partition1Replicas)
@@ -1024,20 +1025,17 @@ class ReplicaManagerTest {
         assertEquals(Errors.NONE, tp0Status.get.error)
         assertTrue(tp0Status.get.records.batches.iterator.hasNext)
 
+        // Replica 1 is not a valid replica for partition 1
         val tp1Status = responseStatusMap.get(tidp1)
-        assertTrue(tp1Status.isDefined)
-        assertEquals(0, tp1Status.get.highWatermark)
-        assertEquals(Some(0), tp0Status.get.lastStableOffset)
-        assertEquals(Errors.NONE, tp1Status.get.error)
-        assertFalse(tp1Status.get.records.batches.iterator.hasNext)
+        assertEquals(Errors.UNKNOWN_LEADER_EPOCH, tp1Status.get.error)
       }
 
       fetchPartitions(
         replicaManager,
         replicaId = 1,
         fetchInfos = Seq(
-          tidp0 -> new PartitionData(Uuid.ZERO_UUID, 1, 0, 100000, 
Optional.empty()),
-          tidp1 -> new PartitionData(Uuid.ZERO_UUID, 1, 0, 100000, 
Optional.empty())
+          tidp0 -> new PartitionData(Uuid.ZERO_UUID, 1, 0, 100000, 
Optional.of[Integer](leaderEpoch)),
+          tidp1 -> new PartitionData(Uuid.ZERO_UUID, 1, 0, 100000, 
Optional.of[Integer](leaderEpoch))
         ),
         responseCallback = fetchCallback,
         maxWaitMs = 1000,
@@ -1354,13 +1352,14 @@ class ReplicaManagerTest {
       ).toMap)
 
       // Make this replica the leader
+      val leaderEpoch = 1
       val leaderAndIsrRequest = new 
LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, 
brokerEpoch,
         Seq(new LeaderAndIsrPartitionState()
           .setTopicName(topic)
           .setPartitionIndex(0)
           .setControllerEpoch(0)
           .setLeader(0)
-          .setLeaderEpoch(1)
+          .setLeaderEpoch(leaderEpoch)
           .setIsr(brokerList)
           .setPartitionEpoch(0)
           .setReplicas(brokerList)
@@ -1368,15 +1367,22 @@ class ReplicaManagerTest {
         Collections.singletonMap(topic, topicId),
         Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build()
       replicaManager.becomeLeaderOrFollower(1, leaderAndIsrRequest, (_, _) => 
())
-      // Avoid the replica selector ignore the follower replica if it not have 
the data that need to fetch
-      
replicaManager.getPartitionOrException(tp0).updateFollowerFetchState(followerBrokerId,
 new LogOffsetMetadata(0), 0, 0, 0)
+
+      // The leader must record the follower's fetch offset to make it 
eligible for follower fetch selection
+      val followerFetchData = new PartitionData(topicId, 0L, 0L, Int.MaxValue, 
Optional.of(Int.box(leaderEpoch)), Optional.empty[Integer])
+      fetchPartitionAsFollower(
+        replicaManager,
+        tidp0,
+        followerFetchData,
+        replicaId = followerBrokerId
+      )
 
       val metadata = new DefaultClientMetadata("rack-b", "client-id",
         InetAddress.getLocalHost, KafkaPrincipal.ANONYMOUS, "default")
 
       // If a preferred read replica is selected, the fetch response returns 
immediately, even if min bytes and timeout conditions are not met.
       val consumerResult = fetchPartitionAsConsumer(replicaManager, tidp0,
-        new PartitionData(Uuid.ZERO_UUID, 0, 0, 100000, Optional.empty()),
+        new PartitionData(topicId, 0, 0, 100000, Optional.empty()),
         minBytes = 1, clientMetadata = Some(metadata), maxWaitMs = 5000)
 
       // Fetch from leader succeeds
diff --git a/core/src/test/scala/unit/kafka/utils/TestUtils.scala 
b/core/src/test/scala/unit/kafka/utils/TestUtils.scala
index 31ba10f79c..e097dbd620 100755
--- a/core/src/test/scala/unit/kafka/utils/TestUtils.scala
+++ b/core/src/test/scala/unit/kafka/utils/TestUtils.scala
@@ -32,7 +32,7 @@ import java.util.{Arrays, Collections, Optional, Properties}
 import com.yammer.metrics.core.{Gauge, Meter}
 import javax.net.ssl.X509TrustManager
 import kafka.api._
-import kafka.cluster.{Broker, EndPoint, AlterPartitionListener}
+import kafka.cluster.{AlterPartitionListener, Broker, EndPoint}
 import kafka.controller.{ControllerEventManager, LeaderIsrAndControllerEpoch}
 import kafka.log._
 import kafka.network.RequestChannel
@@ -66,7 +66,7 @@ import org.apache.kafka.common.security.auth.{KafkaPrincipal, 
KafkaPrincipalSerd
 import org.apache.kafka.common.serialization.{ByteArrayDeserializer, 
ByteArraySerializer, Deserializer, IntegerSerializer, Serializer}
 import org.apache.kafka.common.utils.Utils._
 import org.apache.kafka.common.utils.{Time, Utils}
-import org.apache.kafka.common.{KafkaFuture, TopicPartition}
+import org.apache.kafka.common.{KafkaFuture, TopicPartition, Uuid}
 import org.apache.kafka.controller.QuorumController
 import org.apache.kafka.server.authorizer.{AuthorizableRequestContext, 
Authorizer => JAuthorizer}
 import org.apache.kafka.server.common.MetadataVersion
@@ -382,6 +382,34 @@ object TestUtils extends Logging {
     Admin.create(adminClientProperties)
   }
 
+  def createTopicWithAdminRaw[B <: KafkaBroker](
+    admin: Admin,
+    topic: String,
+    numPartitions: Int = 1,
+    replicationFactor: Int = 1,
+    replicaAssignment: collection.Map[Int, Seq[Int]] = Map.empty,
+    topicConfig: Properties = new Properties,
+  ): Uuid = {
+    val configsMap = new util.HashMap[String, String]()
+    topicConfig.forEach((k, v) => configsMap.put(k.toString, v.toString))
+
+    val result = if (replicaAssignment.isEmpty) {
+      admin.createTopics(Collections.singletonList(new NewTopic(
+        topic, numPartitions, replicationFactor.toShort).configs(configsMap)))
+    } else {
+      val assignment = new util.HashMap[Integer, util.List[Integer]]()
+      replicaAssignment.forKeyValue { case (k, v) =>
+        val replicas = new util.ArrayList[Integer]
+        v.foreach(r => replicas.add(r.asInstanceOf[Integer]))
+        assignment.put(k.asInstanceOf[Integer], replicas)
+      }
+      admin.createTopics(Collections.singletonList(new NewTopic(
+        topic, assignment).configs(configsMap)))
+    }
+
+    result.topicId(topic).get()
+}
+
   def createTopicWithAdmin[B <: KafkaBroker](
     admin: Admin,
     topic: String,
@@ -397,23 +425,15 @@ object TestUtils extends Logging {
       replicaAssignment.size
     }
 
-    val configsMap = new util.HashMap[String, String]()
-    topicConfig.forEach((k, v) => configsMap.put(k.toString, v.toString))
     try {
-      val result = if (replicaAssignment.isEmpty) {
-        admin.createTopics(Collections.singletonList(new NewTopic(
-          topic, numPartitions, 
replicationFactor.toShort).configs(configsMap)))
-      } else {
-        val assignment = new util.HashMap[Integer, util.List[Integer]]()
-        replicaAssignment.forKeyValue { case (k, v) =>
-          val replicas = new util.ArrayList[Integer]
-          v.foreach(r => replicas.add(r.asInstanceOf[Integer]))
-          assignment.put(k.asInstanceOf[Integer], replicas)
-        }
-        admin.createTopics(Collections.singletonList(new NewTopic(
-          topic, assignment).configs(configsMap)))
-      }
-      result.all().get()
+      createTopicWithAdminRaw(
+        admin,
+        topic,
+        numPartitions,
+        replicationFactor,
+        replicaAssignment,
+        topicConfig
+      )
     } catch {
       case e: ExecutionException => if (!(e.getCause != null &&
           e.getCause.isInstanceOf[TopicExistsException] &&
@@ -432,16 +452,24 @@ object TestUtils extends Logging {
     }.toMap
   }
 
+  def describeTopic(
+    admin: Admin,
+    topic: String
+  ): TopicDescription = {
+    val describedTopics = admin.describeTopics(
+      Collections.singleton(topic)
+    ).allTopicNames().get()
+    describedTopics.get(topic)
+  }
+
   def topicHasSameNumPartitionsAndReplicationFactor(adminClient: Admin,
                                                     topic: String,
                                                     numPartitions: Int,
                                                     replicationFactor: Int): 
Boolean = {
-    val describedTopics = adminClient.describeTopics(Collections.
-      singleton(topic)).allTopicNames().get()
-    val description = describedTopics.get(topic)
-    (description != null &&
+    val description = describeTopic(adminClient, topic)
+    description != null &&
       description.partitions().size() == numPartitions &&
-      description.partitions().iterator().next().replicas().size() == 
replicationFactor)
+      description.partitions().iterator().next().replicas().size() == 
replicationFactor
   }
 
   def createOffsetsTopicWithAdmin[B <: KafkaBroker](
diff --git 
a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/partition/UpdateFollowerFetchStateBenchmark.java
 
b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/partition/UpdateFollowerFetchStateBenchmark.java
index f1f3d76ba7..b2cf1ac556 100644
--- 
a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/partition/UpdateFollowerFetchStateBenchmark.java
+++ 
b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/partition/UpdateFollowerFetchStateBenchmark.java
@@ -20,6 +20,7 @@ package org.apache.kafka.jmh.partition;
 import kafka.cluster.DelayedOperations;
 import kafka.cluster.AlterPartitionListener;
 import kafka.cluster.Partition;
+import kafka.cluster.Replica;
 import kafka.log.CleanerConfig;
 import kafka.log.Defaults;
 import kafka.log.LogConfig;
@@ -79,6 +80,8 @@ public class UpdateFollowerFetchStateBenchmark {
     private long nextOffset = 0;
     private LogManager logManager;
     private Partition partition;
+    private Replica replica1;
+    private Replica replica2;
 
     @Setup(Level.Trial)
     public void setUp() {
@@ -127,6 +130,8 @@ public class UpdateFollowerFetchStateBenchmark {
                 alterPartitionListener, delayedOperations,
                 Mockito.mock(MetadataCache.class), logManager, alterPartitionManager);
         partition.makeLeader(partitionState, offsetCheckpoints, topicId);
+        replica1 = partition.getReplica(1).get();
+        replica2 = partition.getReplica(2).get();
     }
 
     // avoid mocked DelayedOperations to avoid mocked class affecting benchmark results
@@ -166,9 +171,9 @@ public class UpdateFollowerFetchStateBenchmark {
     @OutputTimeUnit(TimeUnit.NANOSECONDS)
     public void updateFollowerFetchStateBench() {
         // measure the impact of two follower fetches on the leader
-        partition.updateFollowerFetchState(1, new LogOffsetMetadata(nextOffset, nextOffset, 0),
+        partition.updateFollowerFetchState(replica1, new LogOffsetMetadata(nextOffset, nextOffset, 0),
                 0, 1, nextOffset);
-        partition.updateFollowerFetchState(2, new LogOffsetMetadata(nextOffset, nextOffset, 0),
+        partition.updateFollowerFetchState(replica2, new LogOffsetMetadata(nextOffset, nextOffset, 0),
                 0, 1, nextOffset);
         nextOffset++;
     }
@@ -178,9 +183,9 @@ public class UpdateFollowerFetchStateBenchmark {
     public void updateFollowerFetchStateBenchNoChange() {
         // measure the impact of two follower fetches on the leader when the follower didn't
         // end up fetching anything
-        partition.updateFollowerFetchState(1, new LogOffsetMetadata(nextOffset, nextOffset, 0),
+        partition.updateFollowerFetchState(replica1, new LogOffsetMetadata(nextOffset, nextOffset, 0),
                 0, 1, 100);
-        partition.updateFollowerFetchState(2, new LogOffsetMetadata(nextOffset, nextOffset, 0),
+        partition.updateFollowerFetchState(replica2, new LogOffsetMetadata(nextOffset, nextOffset, 0),
                 0, 1, 100);
     }
 }

Reply via email to