This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new f6890e7  MINOR: Next round of fetcher thread consolidation (#5587)
f6890e7 is described below

commit f6890e78687afbbe09ff34b0a9383b548be77ee6
Author: Jason Gustafson <[email protected]>
AuthorDate: Fri Aug 31 13:10:27 2018 -0700

    MINOR: Next round of fetcher thread consolidation (#5587)
    
    Pull the epoch request build logic up to `AbstractFetcherThread`. Also get rid of the `FetchRequest` indirection.
    
    Reviewers: Ismael Juma <[email protected]>, Rajini Sivaram <[email protected]>
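
    As a quick orientation for the diff below, the consolidated subclass contract of
    `AbstractFetcherThread` after this change looks roughly as follows. This is a
    simplified sketch distilled from the signatures in this patch (constructor
    parameters, method bodies and unrelated members are omitted), not the literal source:

        abstract class AbstractFetcherThread(/* name, clientId, sourceBroker, ... */) {
          type PD = FetchResponse.PartitionData[Records]

          // Subclasses now only supply these hooks; building the leader-epoch request,
          // delaying partitions with errors, and the per-partition truncation bookkeeping
          // live in the base class (buildLeaderEpochRequest, handlePartitionsWithErrors
          // and maybeTruncate are private there).
          protected def buildFetch(partitionMap: Map[TopicPartition, PartitionFetchState]): ResultWithPartitions[Option[FetchRequest.Builder]]
          protected def fetch(fetchRequest: FetchRequest.Builder): Seq[(TopicPartition, PD)]
          protected def truncate(topicPartition: TopicPartition, epochEndOffset: EpochEndOffset): OffsetTruncationState
          protected def getReplica(tp: TopicPartition): Option[Replica]
          protected def fetchEpochsFromLeader(partitions: Map[TopicPartition, Int]): Map[TopicPartition, EpochEndOffset]
        }
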
---
 .../kafka/common/internals/PartitionStates.java    |   4 +
 .../scala/kafka/server/AbstractFetcherThread.scala | 158 ++++++++++++++-------
 .../kafka/server/ReplicaAlterLogDirsThread.scala   |  96 ++++---------
 .../scala/kafka/server/ReplicaFetcherThread.scala  | 118 +++++----------
 .../ReplicaFetcherThreadFatalErrorTest.scala       |   7 +-
 .../kafka/server/AbstractFetcherThreadTest.scala   |  53 ++++---
 .../server/ReplicaAlterLogDirsThreadTest.scala     |  46 +++---
 7 files changed, 222 insertions(+), 260 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/common/internals/PartitionStates.java b/clients/src/main/java/org/apache/kafka/common/internals/PartitionStates.java
index 5b904c2..ba65632 100644
--- a/clients/src/main/java/org/apache/kafka/common/internals/PartitionStates.java
+++ b/clients/src/main/java/org/apache/kafka/common/internals/PartitionStates.java
@@ -93,6 +93,10 @@ public class PartitionStates<S> {
         return result;
     }
 
+    public LinkedHashMap<TopicPartition, S> partitionStateMap() {
+        return new LinkedHashMap<>(map);
+    }
+
     /**
      * Returns the partition state values in order.
      */
diff --git a/core/src/main/scala/kafka/server/AbstractFetcherThread.scala b/core/src/main/scala/kafka/server/AbstractFetcherThread.scala
index fe9fc06..e753f6e 100755
--- a/core/src/main/scala/kafka/server/AbstractFetcherThread.scala
+++ b/core/src/main/scala/kafka/server/AbstractFetcherThread.scala
@@ -39,7 +39,7 @@ import com.yammer.metrics.core.Gauge
 import org.apache.kafka.common.{KafkaException, TopicPartition}
 import org.apache.kafka.common.internals.{FatalExitError, PartitionStates}
 import org.apache.kafka.common.record.{FileRecords, MemoryRecords, Records}
-import org.apache.kafka.common.requests.{EpochEndOffset, FetchResponse}
+import org.apache.kafka.common.requests.{EpochEndOffset, FetchRequest, FetchResponse}
 
 import scala.math._
 
@@ -54,7 +54,6 @@ abstract class AbstractFetcherThread(name: String,
                                      includeLogTruncation: Boolean)
   extends ShutdownableThread(name, isInterruptible) {
 
-  type REQ <: FetchRequest
   type PD = FetchResponse.PartitionData[Records]
 
   private[server] val partitionStates = new PartitionStates[PartitionFetchState]
@@ -74,18 +73,15 @@ abstract class AbstractFetcherThread(name: String,
   // handle a partition whose offset is out of range and return a new fetch offset
   protected def handleOffsetOutOfRange(topicPartition: TopicPartition): Long
 
-  // deal with partitions with errors, potentially due to leadership changes
-  protected def handlePartitionsWithErrors(partitions: Iterable[TopicPartition])
-
-  protected def buildLeaderEpochRequest(allPartitions: Seq[(TopicPartition, PartitionFetchState)]): ResultWithPartitions[Map[TopicPartition, Int]]
-
   protected def fetchEpochsFromLeader(partitions: Map[TopicPartition, Int]): Map[TopicPartition, EpochEndOffset]
 
-  protected def maybeTruncate(fetchedEpochs: Map[TopicPartition, EpochEndOffset]): ResultWithPartitions[Map[TopicPartition, OffsetTruncationState]]
+  protected def truncate(topicPartition: TopicPartition, epochEndOffset: EpochEndOffset): OffsetTruncationState
 
-  protected def buildFetchRequest(partitionMap: Seq[(TopicPartition, PartitionFetchState)]): ResultWithPartitions[REQ]
+  protected def buildFetch(partitionMap: Map[TopicPartition, PartitionFetchState]): ResultWithPartitions[Option[FetchRequest.Builder]]
 
-  protected def fetch(fetchRequest: REQ): Seq[(TopicPartition, PD)]
+  protected def fetch(fetchRequest: FetchRequest.Builder): Seq[(TopicPartition, PD)]
+
+  protected def getReplica(tp: TopicPartition): Option[Replica]
 
   override def shutdown() {
     initiateShutdown()
@@ -99,34 +95,67 @@ abstract class AbstractFetcherThread(name: String,
     fetcherLagStats.unregister()
   }
 
-  private def states() = partitionStates.partitionStates.asScala.map { state => state.topicPartition -> state.value }
-
   override def doWork() {
     maybeTruncate()
-    val fetchRequest = inLock(partitionMapLock) {
-      val ResultWithPartitions(fetchRequest, partitionsWithError) = buildFetchRequest(states)
-      if (fetchRequest.isEmpty) {
+
+    val (fetchStates, fetchRequestOpt) = inLock(partitionMapLock) {
+      val fetchStates = partitionStates.partitionStateMap.asScala
+      val ResultWithPartitions(fetchRequestOpt, partitionsWithError) = buildFetch(fetchStates)
+
+      handlePartitionsWithErrors(partitionsWithError)
+
+      if (fetchRequestOpt.isEmpty) {
         trace(s"There are no active partitions. Back off for $fetchBackOffMs ms before sending a fetch request")
         partitionMapCond.await(fetchBackOffMs, TimeUnit.MILLISECONDS)
       }
-      handlePartitionsWithErrors(partitionsWithError)
-      fetchRequest
+
+      (fetchStates, fetchRequestOpt)
+    }
+
+    fetchRequestOpt.foreach { fetchRequest =>
+      processFetchRequest(fetchStates, fetchRequest)
+    }
+  }
+
+  // deal with partitions with errors, potentially due to leadership changes
+  private def handlePartitionsWithErrors(partitions: Iterable[TopicPartition]) {
+    if (partitions.nonEmpty)
+      delayPartitions(partitions, fetchBackOffMs)
+  }
+
+  /**
+   * Builds offset for leader epoch requests for partitions that are in the truncating phase based
+   * on latest epochs of the future replicas (the one that is fetching)
+   */
+  private def buildLeaderEpochRequest(): ResultWithPartitions[Map[TopicPartition, Int]] = inLock(partitionMapLock) {
+    var partitionsWithoutEpochs = mutable.Set.empty[TopicPartition]
+    var partitionsWithEpochs = mutable.Map.empty[TopicPartition, Int]
+
+    partitionStates.partitionStates.asScala.foreach { state =>
+      val tp = state.topicPartition
+      if (state.value.isTruncatingLog) {
+        getReplica(tp).flatMap(_.epochs).map(_.latestEpoch) match {
+          case Some(latestEpoch) => partitionsWithEpochs += tp -> latestEpoch
+          case None => partitionsWithoutEpochs += tp
+        }
+      }
     }
-    if (!fetchRequest.isEmpty)
-      processFetchRequest(fetchRequest)
+
+    debug(s"Build leaderEpoch request $partitionsWithEpochs")
+    ResultWithPartitions(partitionsWithEpochs, partitionsWithoutEpochs)
   }
 
   /**
     * - Build a leader epoch fetch based on partitions that are in the Truncating phase
-    * - Issue LeaderEpochRequeust, retrieving the latest offset for each partition's
+    * - Send OffsetsForLeaderEpochRequest, retrieving the latest offset for each partition's
     *   leader epoch. This is the offset the follower should truncate to ensure
     *   accurate log replication.
     * - Finally truncate the logs for partitions in the truncating phase and mark them
     *   truncation complete. Do this within a lock to ensure no leadership changes can
     *   occur during truncation.
     */
-  def maybeTruncate(): Unit = {
-    val ResultWithPartitions(epochRequests, partitionsWithError) = inLock(partitionMapLock) { buildLeaderEpochRequest(states) }
+  private def maybeTruncate(): Unit = {
+    val ResultWithPartitions(epochRequests, partitionsWithError) = buildLeaderEpochRequest()
     handlePartitionsWithErrors(partitionsWithError)
 
     if (epochRequests.nonEmpty) {
@@ -142,7 +171,31 @@ abstract class AbstractFetcherThread(name: String,
     }
   }
 
-  private def processFetchRequest(fetchRequest: REQ) {
+  private def maybeTruncate(fetchedEpochs: Map[TopicPartition, EpochEndOffset]): ResultWithPartitions[Map[TopicPartition, OffsetTruncationState]] = {
+    val fetchOffsets = mutable.HashMap.empty[TopicPartition, OffsetTruncationState]
+    val partitionsWithError = mutable.Set[TopicPartition]()
+
+    fetchedEpochs.foreach { case (tp, leaderEpochOffset) =>
+      try {
+        if (leaderEpochOffset.hasError) {
+          info(s"Retrying leaderEpoch request for partition $tp as the leader reported an error: ${leaderEpochOffset.error}")
+          partitionsWithError += tp
+        } else {
+          val offsetTruncationState = truncate(tp, leaderEpochOffset)
+          fetchOffsets.put(tp, offsetTruncationState)
+        }
+      } catch {
+        case e: KafkaStorageException =>
+          info(s"Failed to truncate $tp", e)
+          partitionsWithError += tp
+      }
+    }
+
+    ResultWithPartitions(fetchOffsets, partitionsWithError)
+  }
+
+  private def processFetchRequest(fetchStates: Map[TopicPartition, PartitionFetchState],
+                                  fetchRequest: FetchRequest.Builder): Unit = {
     val partitionsWithError = mutable.Set[TopicPartition]()
     var responseData: Seq[(TopicPartition, PD)] = Seq.empty
 
@@ -169,13 +222,12 @@ abstract class AbstractFetcherThread(name: String,
       inLock(partitionMapLock) {
 
         responseData.foreach { case (topicPartition, partitionData) =>
-          val topic = topicPartition.topic
-          val partitionId = topicPartition.partition
-          Option(partitionStates.stateValue(topicPartition)).foreach(currentPartitionFetchState =>
+          Option(partitionStates.stateValue(topicPartition)).foreach { currentPartitionFetchState =>
             // It's possible that a partition is removed and re-added or truncated when there is a pending fetch request.
-            // In this case, we only want to process the fetch response if the partition state is ready for fetch and the current offset is the same as the offset requested.
-            if (fetchRequest.offset(topicPartition) == currentPartitionFetchState.fetchOffset &&
-                currentPartitionFetchState.isReadyForFetch) {
+            // In this case, we only want to process the fetch response if the partition state is ready for fetch and
+            // the current offset is the same as the offset requested.
+            val fetchOffset = fetchStates(topicPartition).fetchOffset
+            if (fetchOffset == currentPartitionFetchState.fetchOffset && currentPartitionFetchState.isReadyForFetch) {
               partitionData.error match {
                 case Errors.NONE =>
                   try {
@@ -183,7 +235,7 @@ abstract class AbstractFetcherThread(name: String,
                     val newOffset = records.batches.asScala.lastOption.map(_.nextOffset).getOrElse(
                       currentPartitionFetchState.fetchOffset)
 
-                    fetcherLagStats.getAndMaybePut(topic, partitionId).lag = Math.max(0L, partitionData.highWatermark - newOffset)
+                    fetcherLagStats.getAndMaybePut(topicPartition).lag = Math.max(0L, partitionData.highWatermark - newOffset)
                     // Once we hand off the partition data to the subclass, we can't mess with it any more in this thread
                     processPartitionData(topicPartition, currentPartitionFetchState.fetchOffset, partitionData, records)
 
@@ -232,7 +284,8 @@ abstract class AbstractFetcherThread(name: String,
                     partitionData.error.exception)
                   partitionsWithError += topicPartition
               }
-            })
+            }
+          }
         }
       }
     }
@@ -270,8 +323,11 @@ abstract class AbstractFetcherThread(name: String,
             new PartitionFetchState(initialFetchOffset, includeLogTruncation)
         tp -> fetchState
       }
-      val existingPartitionToState = states().toMap
-      partitionStates.set((existingPartitionToState ++ newPartitionToState).asJava)
+
+      newPartitionToState.foreach { case (tp, state) =>
+        partitionStates.updateAndMoveToEnd(tp, state)
+      }
+
       partitionMapCond.signalAll()
     } finally partitionMapLock.unlock()
   }
@@ -373,7 +429,8 @@ abstract class AbstractFetcherThread(name: String,
       for (partition <- partitions) {
         Option(partitionStates.stateValue(partition)).foreach (currentPartitionFetchState =>
           if (!currentPartitionFetchState.isDelayed)
-            partitionStates.updateAndMoveToEnd(partition, PartitionFetchState(currentPartitionFetchState.fetchOffset, new DelayedItem(delay), currentPartitionFetchState.truncatingLog))
+            partitionStates.updateAndMoveToEnd(partition, PartitionFetchState(currentPartitionFetchState.fetchOffset,
+              new DelayedItem(delay), currentPartitionFetchState.truncatingLog))
         )
       }
       partitionMapCond.signalAll()
@@ -385,7 +442,7 @@ abstract class AbstractFetcherThread(name: String,
     try {
       topicPartitions.foreach { topicPartition =>
         partitionStates.remove(topicPartition)
-        fetcherLagStats.unregister(topicPartition.topic, topicPartition.partition)
+        fetcherLagStats.unregister(topicPartition)
       }
     } finally partitionMapLock.unlock()
   }
@@ -397,8 +454,8 @@ abstract class AbstractFetcherThread(name: String,
   }
 
   private[server] def partitionsAndOffsets: Map[TopicPartition, BrokerAndInitialOffset] = inLock(partitionMapLock) {
-    partitionStates.partitionStates.asScala.map { case state =>
-      state.topicPartition -> new BrokerAndInitialOffset(sourceBroker, state.value.fetchOffset)
+    partitionStates.partitionStates.asScala.map { state =>
+      state.topicPartition -> BrokerAndInitialOffset(sourceBroker, state.value.fetchOffset)
     }.toMap
   }
 
@@ -418,11 +475,6 @@ object AbstractFetcherThread {
 
   case class ResultWithPartitions[R](result: R, partitionsWithError: Set[TopicPartition])
 
-  trait FetchRequest {
-    def isEmpty: Boolean
-    def offset(topicPartition: TopicPartition): Long
-  }
-
 }
 
 object FetcherMetrics {
@@ -436,8 +488,8 @@ class FetcherLagMetrics(metricId: ClientIdTopicPartition) extends KafkaMetricsGr
   private[this] val lagVal = new AtomicLong(-1L)
   private[this] val tags = Map(
     "clientId" -> metricId.clientId,
-    "topic" -> metricId.topic,
-    "partition" -> metricId.partitionId.toString)
+    "topic" -> metricId.topicPartition.topic,
+    "partition" -> metricId.topicPartition.partition.toString)
 
   newGauge(FetcherMetrics.ConsumerLag,
     new Gauge[Long] {
@@ -461,26 +513,26 @@ class FetcherLagStats(metricId: ClientIdAndBroker) {
   private val valueFactory = (k: ClientIdTopicPartition) => new FetcherLagMetrics(k)
   val stats = new Pool[ClientIdTopicPartition, FetcherLagMetrics](Some(valueFactory))
 
-  def getAndMaybePut(topic: String, partitionId: Int): FetcherLagMetrics = {
-    stats.getAndMaybePut(ClientIdTopicPartition(metricId.clientId, topic, partitionId))
+  def getAndMaybePut(topicPartition: TopicPartition): FetcherLagMetrics = {
+    stats.getAndMaybePut(ClientIdTopicPartition(metricId.clientId, topicPartition))
   }
 
-  def isReplicaInSync(topic: String, partitionId: Int): Boolean = {
-    val fetcherLagMetrics = stats.get(ClientIdTopicPartition(metricId.clientId, topic, partitionId))
+  def isReplicaInSync(topicPartition: TopicPartition): Boolean = {
+    val fetcherLagMetrics = stats.get(ClientIdTopicPartition(metricId.clientId, topicPartition))
     if (fetcherLagMetrics != null)
       fetcherLagMetrics.lag <= 0
     else
       false
   }
 
-  def unregister(topic: String, partitionId: Int) {
-    val lagMetrics = stats.remove(ClientIdTopicPartition(metricId.clientId, topic, partitionId))
+  def unregister(topicPartition: TopicPartition) {
+    val lagMetrics = stats.remove(ClientIdTopicPartition(metricId.clientId, topicPartition))
     if (lagMetrics != null) lagMetrics.unregister()
   }
 
   def unregister() {
     stats.keys.toBuffer.foreach { key: ClientIdTopicPartition =>
-      unregister(key.topic, key.partitionId)
+      unregister(key.topicPartition)
     }
   }
 }
@@ -501,8 +553,8 @@ class FetcherStats(metricId: ClientIdAndBroker) extends KafkaMetricsGroup {
 
 }
 
-case class ClientIdTopicPartition(clientId: String, topic: String, partitionId: Int) {
-  override def toString = "%s-%s-%d".format(clientId, topic, partitionId)
+case class ClientIdTopicPartition(clientId: String, topicPartition: TopicPartition) {
+  override def toString: String = s"$clientId-$topicPartition"
 }
 
 /**
diff --git a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala
index 05dc356..1621201 100644
--- a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala
+++ b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala
@@ -20,18 +20,16 @@ package kafka.server
 import java.util
 
 import kafka.api.Request
-import kafka.cluster.BrokerEndPoint
+import kafka.cluster.{BrokerEndPoint, Replica}
 import kafka.server.AbstractFetcherThread.ResultWithPartitions
 import kafka.server.QuotaFactory.UnboundedQuota
-import kafka.server.ReplicaAlterLogDirsThread.FetchRequest
-import kafka.server.epoch.LeaderEpochCache
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.errors.KafkaStorageException
 import org.apache.kafka.common.protocol.{ApiKeys, Errors}
 import org.apache.kafka.common.record.{MemoryRecords, Records}
 import org.apache.kafka.common.requests.EpochEndOffset._
 import org.apache.kafka.common.requests.FetchResponse.PartitionData
-import org.apache.kafka.common.requests.{EpochEndOffset, FetchResponse, FetchRequest => JFetchRequest}
+import org.apache.kafka.common.requests.{EpochEndOffset, FetchRequest, FetchResponse}
 
 import scala.collection.JavaConverters._
 import scala.collection.{Map, Seq, Set, mutable}
@@ -49,15 +47,17 @@ class ReplicaAlterLogDirsThread(name: String,
                                 isInterruptible = false,
                                 includeLogTruncation = true) {
 
-  type REQ = FetchRequest
-
   private val replicaId = brokerConfig.brokerId
   private val maxBytes = brokerConfig.replicaFetchResponseMaxBytes
   private val fetchSize = brokerConfig.replicaFetchMaxBytes
 
-  def fetch(fetchRequest: FetchRequest): Seq[(TopicPartition, PD)] = {
+  protected def getReplica(tp: TopicPartition): Option[Replica] = {
+    replicaMgr.getReplica(tp, Request.FutureLocalReplicaId)
+  }
+
+  def fetch(fetchRequest: FetchRequest.Builder): Seq[(TopicPartition, PD)] = {
     var partitionData: Seq[(TopicPartition, FetchResponse.PartitionData[Records])] = null
-    val request = fetchRequest.underlying.build()
+    val request = fetchRequest.build()
 
     def processResponseCallback(responsePartitionData: Seq[(TopicPartition, FetchPartitionData)]) {
       partitionData = responsePartitionData.map { case (tp, data) =>
@@ -129,28 +129,6 @@ class ReplicaAlterLogDirsThread(name: String,
     }
   }
 
-  def handlePartitionsWithErrors(partitions: Iterable[TopicPartition]) {
-    if (partitions.nonEmpty)
-      delayPartitions(partitions, brokerConfig.replicaFetchBackoffMs.toLong)
-  }
-
-  /**
-   * Builds offset for leader epoch requests for partitions that are in the truncating phase based
-   * on latest epochs of the future replicas (the one that is fetching)
-   */
-  def buildLeaderEpochRequest(allPartitions: Seq[(TopicPartition, PartitionFetchState)]): ResultWithPartitions[Map[TopicPartition, Int]] = {
-    def epochCacheOpt(tp: TopicPartition): Option[LeaderEpochCache] = replicaMgr.getReplica(tp, Request.FutureLocalReplicaId).map(_.epochs.get)
-
-    val partitionEpochOpts = allPartitions
-      .filter { case (_, state) => state.isTruncatingLog }
-      .map { case (tp, _) => tp -> epochCacheOpt(tp) }.toMap
-
-    val (partitionsWithEpoch, partitionsWithoutEpoch) = partitionEpochOpts.partition { case (_, epochCacheOpt) => epochCacheOpt.nonEmpty }
-
-    val result = partitionsWithEpoch.map { case (tp, epochCacheOpt) => tp -> epochCacheOpt.get.latestEpoch }
-    ResultWithPartitions(result, partitionsWithoutEpoch.keys.toSet)
-  }
-
   /**
   * Fetches offset for leader epoch from local replica for each given topic partitions
   * @param partitions map of topic partition -> leader epoch of the future replica
@@ -183,34 +161,17 @@ class ReplicaAlterLogDirsThread(name: String,
   * the future replica may miss "mark for truncation" event and must use the offset for leader epoch
   * exchange with the current replica to truncate to the largest common log prefix for the topic partition
    */
-  def maybeTruncate(fetchedEpochs: Map[TopicPartition, EpochEndOffset]): ResultWithPartitions[Map[TopicPartition, OffsetTruncationState]] = {
-    val fetchOffsets = scala.collection.mutable.HashMap.empty[TopicPartition, OffsetTruncationState]
-    val partitionsWithError = mutable.Set[TopicPartition]()
-
-    fetchedEpochs.foreach { case (topicPartition, epochOffset) =>
-      try {
-        val futureReplica = replicaMgr.getReplicaOrException(topicPartition, Request.FutureLocalReplicaId)
-        val partition = replicaMgr.getPartition(topicPartition).get
-
-        if (epochOffset.hasError) {
-          info(s"Retrying leaderEpoch request for partition $topicPartition as the current replica reported an error: ${epochOffset.error}")
-          partitionsWithError += topicPartition
-        } else {
-          val offsetTruncationState = getOffsetTruncationState(topicPartition, epochOffset, futureReplica, isFutureReplica = true)
+  override def truncate(topicPartition: TopicPartition, epochEndOffset: EpochEndOffset): OffsetTruncationState = {
+    val futureReplica = replicaMgr.getReplicaOrException(topicPartition, Request.FutureLocalReplicaId)
+    val partition = replicaMgr.getPartition(topicPartition).get
 
-          partition.truncateTo(offsetTruncationState.offset, isFuture = true)
-          fetchOffsets.put(topicPartition, offsetTruncationState)
-        }
-      } catch {
-        case e: KafkaStorageException =>
-          info(s"Failed to truncate $topicPartition", e)
-          partitionsWithError += topicPartition
-      }
-    }
-    ResultWithPartitions(fetchOffsets, partitionsWithError)
+    val offsetTruncationState = getOffsetTruncationState(topicPartition, epochEndOffset, futureReplica,
+      isFutureReplica = true)
+    partition.truncateTo(offsetTruncationState.offset, isFuture = true)
+    offsetTruncationState
   }
 
-  def buildFetchRequest(partitionMap: Seq[(TopicPartition, PartitionFetchState)]): ResultWithPartitions[FetchRequest] = {
+  def buildFetch(partitionMap: Map[TopicPartition, PartitionFetchState]): ResultWithPartitions[Option[FetchRequest.Builder]] = {
     // Only include replica in the fetch request if it is not throttled.
     val maxPartitionOpt = partitionMap.filter { case (_, partitionFetchState) =>
       partitionFetchState.isReadyForFetch && !quota.isQuotaExceeded
@@ -223,32 +184,29 @@ class ReplicaAlterLogDirsThread(name: String,
 
     // Only move one replica at a time to increase its catch-up rate and thus reduce the time spent on moving any given replica
     // Replicas are ordered by their TopicPartition
-    val requestMap = new util.LinkedHashMap[TopicPartition, JFetchRequest.PartitionData]
+    val requestMap = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
     val partitionsWithError = mutable.Set[TopicPartition]()
 
     if (maxPartitionOpt.nonEmpty) {
       val (topicPartition, partitionFetchState) = maxPartitionOpt.get
       try {
         val logStartOffset = replicaMgr.getReplicaOrException(topicPartition, Request.FutureLocalReplicaId).logStartOffset
-        requestMap.put(topicPartition, new JFetchRequest.PartitionData(partitionFetchState.fetchOffset, logStartOffset, fetchSize))
+        requestMap.put(topicPartition, new FetchRequest.PartitionData(partitionFetchState.fetchOffset, logStartOffset, fetchSize))
       } catch {
         case _: KafkaStorageException =>
           partitionsWithError += topicPartition
       }
     }
-    // Set maxWait and minBytes to 0 because the response should return immediately if
-    // the future log has caught up with the current log of the partition
-    val requestBuilder = JFetchRequest.Builder.forReplica(ApiKeys.FETCH.latestVersion, replicaId, 0, 0, requestMap).setMaxBytes(maxBytes)
-    ResultWithPartitions(new FetchRequest(requestBuilder), partitionsWithError)
-  }
-}
 
-object ReplicaAlterLogDirsThread {
-
-  private[server] class FetchRequest(val underlying: JFetchRequest.Builder) extends AbstractFetcherThread.FetchRequest {
-    def isEmpty: Boolean = underlying.fetchData.isEmpty
-    def offset(topicPartition: TopicPartition): Long = underlying.fetchData.asScala(topicPartition).fetchOffset
-    override def toString = underlying.toString
+    val fetchRequestOpt = if (requestMap.isEmpty) {
+      None
+    } else {
+      // Set maxWait and minBytes to 0 because the response should return immediately if
+      // the future log has caught up with the current log of the partition
+      Some(FetchRequest.Builder.forReplica(ApiKeys.FETCH.latestVersion, replicaId, 0, 0, requestMap)
+        .setMaxBytes(maxBytes))
+    }
+    ResultWithPartitions(fetchRequestOpt, partitionsWithError)
   }
 
 }
diff --git a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
index 3b1a54f..5624e84 100644
--- a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
+++ b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
@@ -17,24 +17,20 @@
 
 package kafka.server
 
-import java.util
-
-import AbstractFetcherThread.ResultWithPartitions
 import kafka.api._
-import kafka.cluster.BrokerEndPoint
+import kafka.cluster.{BrokerEndPoint, Replica}
 import kafka.log.LogConfig
-import kafka.server.ReplicaFetcherThread._
-import kafka.server.epoch.LeaderEpochCache
+import kafka.server.AbstractFetcherThread.ResultWithPartitions
 import kafka.zk.AdminZkClient
 import org.apache.kafka.clients.FetchSessionHandler
-import org.apache.kafka.common.requests.EpochEndOffset._
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.errors.KafkaStorageException
 import org.apache.kafka.common.internals.FatalExitError
 import org.apache.kafka.common.metrics.Metrics
 import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.record.{MemoryRecords, Records}
-import org.apache.kafka.common.requests.{EpochEndOffset, FetchResponse, ListOffsetRequest, ListOffsetResponse, OffsetsForLeaderEpochRequest, OffsetsForLeaderEpochResponse, FetchRequest => JFetchRequest}
+import org.apache.kafka.common.requests.EpochEndOffset._
+import org.apache.kafka.common.requests._
 import org.apache.kafka.common.utils.{LogContext, Time}
 
 import scala.collection.JavaConverters._
@@ -56,8 +52,6 @@ class ReplicaFetcherThread(name: String,
                                 isInterruptible = false,
                                 includeLogTruncation = true) {
 
-  type REQ = FetchRequest
-
   private val replicaId = brokerConfig.brokerId
   private val logContext = new LogContext(s"[ReplicaFetcher replicaId=$replicaId, leaderId=${sourceBroker.id}, " +
     s"fetcherId=$fetcherId] ")
@@ -90,7 +84,6 @@ class ReplicaFetcherThread(name: String,
     else if (brokerConfig.interBrokerProtocolVersion >= KAFKA_0_10_1_IV2) 1
     else 0
 
-  private val fetchMetadataSupported = brokerConfig.interBrokerProtocolVersion >= KAFKA_1_1_IV0
   private val maxWait = brokerConfig.replicaFetchWaitMaxMs
   private val minBytes = brokerConfig.replicaFetchMinBytes
   private val maxBytes = brokerConfig.replicaFetchResponseMaxBytes
@@ -99,7 +92,9 @@ class ReplicaFetcherThread(name: String,
 
   private val fetchSessionHandler = new FetchSessionHandler(logContext, sourceBroker.id)
 
-  private def epochCacheOpt(tp: TopicPartition): Option[LeaderEpochCache] =  replicaMgr.getReplica(tp).map(_.epochs.get)
+  protected def getReplica(tp: TopicPartition): Option[Replica] = {
+    replicaMgr.getReplica(tp)
+  }
 
   override def initiateShutdown(): Boolean = {
     val justShutdown = super.initiateShutdown()
@@ -228,15 +223,9 @@ class ReplicaFetcherThread(name: String,
     }
   }
 
-  // any logic for partitions whose leader has changed
-  def handlePartitionsWithErrors(partitions: Iterable[TopicPartition]) {
-    if (partitions.nonEmpty)
-      delayPartitions(partitions, brokerConfig.replicaFetchBackoffMs.toLong)
-  }
-
-  protected def fetch(fetchRequest: FetchRequest): Seq[(TopicPartition, PD)] = {
+  protected def fetch(fetchRequest: FetchRequest.Builder): Seq[(TopicPartition, PD)] = {
     try {
-      val clientResponse = leaderEndpoint.sendRequest(fetchRequest.underlying)
+      val clientResponse = leaderEndpoint.sendRequest(fetchRequest)
       val fetchResponse = clientResponse.responseBody.asInstanceOf[FetchResponse[Records]]
       if (!fetchSessionHandler.handleResponse(fetchResponse)) {
         Nil
@@ -271,7 +260,7 @@ class ReplicaFetcherThread(name: String,
     }
   }
 
-  override def buildFetchRequest(partitionMap: Seq[(TopicPartition, PartitionFetchState)]): ResultWithPartitions[FetchRequest] = {
+  override def buildFetch(partitionMap: Map[TopicPartition, PartitionFetchState]): ResultWithPartitions[Option[FetchRequest.Builder]] = {
     val partitionsWithError = mutable.Set[TopicPartition]()
 
     val builder = fetchSessionHandler.newBuilder()
@@ -280,7 +269,7 @@ class ReplicaFetcherThread(name: String,
       if (partitionFetchState.isReadyForFetch && !shouldFollowerThrottle(quota, topicPartition)) {
         try {
           val logStartOffset = replicaMgr.getReplicaOrException(topicPartition).logStartOffset
-          builder.add(topicPartition, new JFetchRequest.PartitionData(
+          builder.add(topicPartition, new FetchRequest.PartitionData(
             partitionFetchState.fetchOffset, logStartOffset, fetchSize))
         } catch {
           case _: KafkaStorageException =>
@@ -292,63 +281,37 @@ class ReplicaFetcherThread(name: String,
     }
 
     val fetchData = builder.build()
-    val requestBuilder = JFetchRequest.Builder.
-      forReplica(fetchRequestVersion, replicaId, maxWait, minBytes, fetchData.toSend())
+    val fetchRequestOpt = if (fetchData.sessionPartitions.isEmpty && fetchData.toForget.isEmpty) {
+      None
+    } else {
+      val requestBuilder = FetchRequest.Builder
+        .forReplica(fetchRequestVersion, replicaId, maxWait, minBytes, fetchData.toSend)
         .setMaxBytes(maxBytes)
         .toForget(fetchData.toForget)
-    if (fetchMetadataSupported) {
-      requestBuilder.metadata(fetchData.metadata())
+        .metadata(fetchData.metadata)
+      Some(requestBuilder)
     }
-    ResultWithPartitions(new FetchRequest(fetchData.sessionPartitions(), requestBuilder), partitionsWithError)
+
+    ResultWithPartitions(fetchRequestOpt, partitionsWithError)
   }
 
   /**
   * Truncate the log for each partition's epoch based on leader's returned epoch and offset.
   * The logic for finding the truncation offset is implemented in AbstractFetcherThread.getOffsetTruncationState
    */
-  override def maybeTruncate(fetchedEpochs: Map[TopicPartition, EpochEndOffset]): ResultWithPartitions[Map[TopicPartition, OffsetTruncationState]] = {
-    val fetchOffsets = scala.collection.mutable.HashMap.empty[TopicPartition, OffsetTruncationState]
-    val partitionsWithError = mutable.Set[TopicPartition]()
-
-    fetchedEpochs.foreach { case (tp, leaderEpochOffset) =>
-      try {
-        val replica = replicaMgr.getReplicaOrException(tp)
-        val partition = replicaMgr.getPartition(tp).get
-
-        if (leaderEpochOffset.hasError) {
-          info(s"Retrying leaderEpoch request for partition ${replica.topicPartition} as the leader reported an error: ${leaderEpochOffset.error}")
-          partitionsWithError += tp
-        } else {
-          val offsetTruncationState = getOffsetTruncationState(tp, leaderEpochOffset, replica)
-          if (offsetTruncationState.offset < replica.highWatermark.messageOffset)
-            warn(s"Truncating $tp to offset ${offsetTruncationState.offset} below high watermark ${replica.highWatermark.messageOffset}")
-
-          partition.truncateTo(offsetTruncationState.offset, isFuture = false)
-          // mark the future replica for truncation only when we do last truncation
-          if (offsetTruncationState.truncationCompleted)
-            replicaMgr.replicaAlterLogDirsManager.markPartitionsForTruncation(brokerConfig.brokerId, tp, offsetTruncationState.offset)
-          fetchOffsets.put(tp, offsetTruncationState)
-        }
-      } catch {
-        case e: KafkaStorageException =>
-          info(s"Failed to truncate $tp", e)
-          partitionsWithError += tp
-      }
-    }
-
-    ResultWithPartitions(fetchOffsets, partitionsWithError)
-  }
-
-  override def buildLeaderEpochRequest(allPartitions: Seq[(TopicPartition, PartitionFetchState)]): ResultWithPartitions[Map[TopicPartition, Int]] = {
-    val partitionEpochOpts = allPartitions
-      .filter { case (_, state) => state.isTruncatingLog }
-      .map { case (tp, _) => tp -> epochCacheOpt(tp) }.toMap
-
-    val (partitionsWithEpoch, partitionsWithoutEpoch) = partitionEpochOpts.partition { case (_, epochCacheOpt) => epochCacheOpt.nonEmpty }
-
-    debug(s"Build leaderEpoch request $partitionsWithEpoch")
-    val result = partitionsWithEpoch.map { case (tp, epochCacheOpt) => tp -> epochCacheOpt.get.latestEpoch }
-    ResultWithPartitions(result, partitionsWithoutEpoch.keys.toSet)
+  override def truncate(tp: TopicPartition, epochEndOffset: EpochEndOffset): OffsetTruncationState = {
+    val replica = replicaMgr.getReplicaOrException(tp)
+    val partition = replicaMgr.getPartition(tp).get
+
+    val offsetTruncationState = getOffsetTruncationState(tp, epochEndOffset, replica)
+    if (offsetTruncationState.offset < replica.highWatermark.messageOffset)
+      warn(s"Truncating $tp to offset ${offsetTruncationState.offset} below high watermark ${replica.highWatermark.messageOffset}")
+    partition.truncateTo(offsetTruncationState.offset, isFuture = false)
+
+    // mark the future replica for truncation only when we do last truncation
+    if (offsetTruncationState.truncationCompleted)
+      replicaMgr.replicaAlterLogDirsManager.markPartitionsForTruncation(brokerConfig.brokerId, tp, offsetTruncationState.offset)
+    offsetTruncationState
   }
 
   override def fetchEpochsFromLeader(partitions: Map[TopicPartition, Int]): Map[TopicPartition, EpochEndOffset] = {
@@ -394,20 +357,7 @@ class ReplicaFetcherThread(name: String,
    *  the quota is exceeded and the replica is not in sync.
    */
   private def shouldFollowerThrottle(quota: ReplicaQuota, topicPartition: TopicPartition): Boolean = {
-    val isReplicaInSync = fetcherLagStats.isReplicaInSync(topicPartition.topic, topicPartition.partition)
+    val isReplicaInSync = fetcherLagStats.isReplicaInSync(topicPartition)
     quota.isThrottled(topicPartition) && quota.isQuotaExceeded && !isReplicaInSync
   }
 }
-
-object ReplicaFetcherThread {
-
-  private[server] class FetchRequest(val sessionParts: util.Map[TopicPartition, JFetchRequest.PartitionData],
-                                     val underlying: JFetchRequest.Builder)
-      extends AbstractFetcherThread.FetchRequest {
-    def offset(topicPartition: TopicPartition): Long =
-      sessionParts.get(topicPartition).fetchOffset
-    override def isEmpty = sessionParts.isEmpty && underlying.toForget().isEmpty
-    override def toString = underlying.toString
-  }
-
-}
diff --git a/core/src/test/scala/integration/kafka/server/ReplicaFetcherThreadFatalErrorTest.scala b/core/src/test/scala/integration/kafka/server/ReplicaFetcherThreadFatalErrorTest.scala
index 2f6db61..6fcf0cc 100644
--- a/core/src/test/scala/integration/kafka/server/ReplicaFetcherThreadFatalErrorTest.scala
+++ b/core/src/test/scala/integration/kafka/server/ReplicaFetcherThreadFatalErrorTest.scala
@@ -20,7 +20,6 @@ package kafka.server
 import java.util.concurrent.atomic.AtomicBoolean
 
 import kafka.cluster.BrokerEndPoint
-import kafka.server.ReplicaFetcherThread.FetchRequest
 import kafka.utils.{Exit, TestUtils}
 import kafka.utils.TestUtils.createBrokerConfigs
 import kafka.zk.ZooKeeperTestHarness
@@ -29,7 +28,7 @@ import org.apache.kafka.common.internals.FatalExitError
 import org.apache.kafka.common.metrics.Metrics
 import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.record.Records
-import org.apache.kafka.common.requests.FetchResponse
+import org.apache.kafka.common.requests.{FetchRequest, FetchResponse}
 import org.apache.kafka.common.utils.Time
 import org.junit.{After, Test}
 
@@ -89,8 +88,8 @@ class ReplicaFetcherThreadFatalErrorTest extends ZooKeeperTestHarness {
       import params._
       new ReplicaFetcherThread(threadName, fetcherId, sourceBroker, config, replicaManager, metrics, time, quotaManager) {
         override def handleOffsetOutOfRange(topicPartition: TopicPartition): Long = throw new FatalExitError
-        override protected def fetch(fetchRequest: FetchRequest): Seq[(TopicPartition, PD)] = {
-          fetchRequest.underlying.fetchData.asScala.keys.toSeq.map { tp =>
+        override protected def fetch(fetchRequest: FetchRequest.Builder): Seq[(TopicPartition, PD)] = {
+          fetchRequest.fetchData.asScala.keys.toSeq.map { tp =>
             (tp, new FetchResponse.PartitionData[Records](Errors.OFFSET_OUT_OF_RANGE,
               FetchResponse.INVALID_HIGHWATERMARK, FetchResponse.INVALID_LAST_STABLE_OFFSET,
               FetchResponse.INVALID_LOG_START_OFFSET, null, null))
diff --git a/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala b/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala
index db98a87..15abc68 100644
--- a/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala
+++ b/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala
@@ -19,13 +19,12 @@ package kafka.server
 
 import AbstractFetcherThread._
 import com.yammer.metrics.Metrics
-import kafka.cluster.BrokerEndPoint
-import kafka.server.AbstractFetcherThread.FetchRequest
+import kafka.cluster.{BrokerEndPoint, Replica}
 import kafka.utils.TestUtils
 import org.apache.kafka.common.TopicPartition
-import org.apache.kafka.common.protocol.Errors
+import org.apache.kafka.common.protocol.{ApiKeys, Errors}
 import org.apache.kafka.common.record.{CompressionType, MemoryRecords, Records, SimpleRecord}
-import org.apache.kafka.common.requests.EpochEndOffset
+import org.apache.kafka.common.requests.{EpochEndOffset, FetchRequest}
 import org.apache.kafka.common.requests.FetchResponse.PartitionData
 import org.junit.Assert.{assertFalse, assertTrue}
 import org.junit.{Before, Test}
@@ -87,19 +86,23 @@ class AbstractFetcherThreadTest {
 
   private def allMetricsNames = Metrics.defaultRegistry().allMetrics().asScala.keySet.map(_.getName)
 
-  class DummyFetchRequest(val offsets: collection.Map[TopicPartition, Long]) extends FetchRequest {
-    override def isEmpty: Boolean = offsets.isEmpty
-
-    override def offset(topicPartition: TopicPartition): Long = offsets(topicPartition)
+  protected def fetchRequestBuilder(partitionMap: collection.Map[TopicPartition, PartitionFetchState]): FetchRequest.Builder = {
+    val partitionData = partitionMap.map { case (tp, fetchState) =>
+      tp -> new FetchRequest.PartitionData(fetchState.fetchOffset, 0, 1024 * 1024)
+    }.toMap.asJava
+    FetchRequest.Builder.forReplica(ApiKeys.FETCH.latestVersion, 0, 0, 1, partitionData)
   }
 
   class DummyFetcherThread(name: String,
                            clientId: String,
                            sourceBroker: BrokerEndPoint,
                            fetchBackOffMs: Int = 0)
-    extends AbstractFetcherThread(name, clientId, sourceBroker, fetchBackOffMs, isInterruptible = true, includeLogTruncation = false) {
+    extends AbstractFetcherThread(name, clientId, sourceBroker,
+      fetchBackOffMs,
+      isInterruptible = true,
+      includeLogTruncation = false) {
 
-    type REQ = DummyFetchRequest
+    protected def getReplica(tp: TopicPartition): Option[Replica] = None
 
     override def processPartitionData(topicPartition: TopicPartition,
                                       fetchOffset: Long,
@@ -108,27 +111,21 @@ class AbstractFetcherThreadTest {
 
     override def handleOffsetOutOfRange(topicPartition: TopicPartition): Long = 0L
 
-    override def handlePartitionsWithErrors(partitions: Iterable[TopicPartition]): Unit = {}
-
-    override protected def fetch(fetchRequest: DummyFetchRequest): Seq[(TopicPartition, PD)] =
-      fetchRequest.offsets.mapValues(_ => new PartitionData[Records](Errors.NONE, 0, 0, 0,
+    override protected def fetch(fetchRequest: FetchRequest.Builder): Seq[(TopicPartition, PD)] =
+      fetchRequest.fetchData.asScala.mapValues(_ => new PartitionData[Records](Errors.NONE, 0, 0, 0,
         Seq.empty.asJava, MemoryRecords.EMPTY)).toSeq
 
-    override protected def buildFetchRequest(partitionMap: collection.Seq[(TopicPartition, PartitionFetchState)]): ResultWithPartitions[DummyFetchRequest] =
-      ResultWithPartitions(new DummyFetchRequest(partitionMap.map { case (k, v) => (k, v.fetchOffset) }.toMap), Set())
-
-    override def buildLeaderEpochRequest(allPartitions: Seq[(TopicPartition, PartitionFetchState)]): ResultWithPartitions[Map[TopicPartition, Int]] = {
-      ResultWithPartitions(Map(), Set())
+    override protected def buildFetch(partitionMap: collection.Map[TopicPartition, PartitionFetchState]): ResultWithPartitions[Option[FetchRequest.Builder]] = {
+      ResultWithPartitions(Some(fetchRequestBuilder(partitionMap)), Set())
     }
 
     override def fetchEpochsFromLeader(partitions: Map[TopicPartition, Int]): Map[TopicPartition, EpochEndOffset] = { Map() }
 
-    override def maybeTruncate(fetchedEpochs: Map[TopicPartition, EpochEndOffset]): ResultWithPartitions[Map[TopicPartition, OffsetTruncationState]] = {
-      ResultWithPartitions(Map(), Set())
+    override def truncate(tp: TopicPartition, epochEndOffset: EpochEndOffset): OffsetTruncationState = {
+      OffsetTruncationState(epochEndOffset.endOffset, truncationCompleted = true)
     }
   }
 
-
   @Test
   def testFetchRequestCorruptedMessageException() {
     val partition = new TopicPartition("topic", 0)
@@ -182,7 +179,7 @@ class AbstractFetcherThreadTest {
       }
     }
 
-    override protected def fetch(fetchRequest: DummyFetchRequest): Seq[(TopicPartition, PD)] = {
+    override protected def fetch(fetchRequest: FetchRequest.Builder): Seq[(TopicPartition, PD)] = {
       fetchCount += 1
       // Set the first fetch to get a corrupted message
       if (fetchCount == 1) {
@@ -193,26 +190,24 @@ class AbstractFetcherThreadTest {
         // flip some bits in the message to ensure the crc fails
         buffer.putInt(15, buffer.getInt(15) ^ 23422)
         buffer.putInt(30, buffer.getInt(30) ^ 93242)
-        fetchRequest.offsets.mapValues(_ => new PartitionData[Records](Errors.NONE, 0L, 0L, 0L,
+        fetchRequest.fetchData.asScala.mapValues(_ => new PartitionData[Records](Errors.NONE, 0L, 0L, 0L,
           Seq.empty.asJava, records)).toSeq
       } else {
         // Then, the following fetches get the normal data
-        fetchRequest.offsets.mapValues(v => normalPartitionDataSet(v.toInt)).toSeq
+        fetchRequest.fetchData.asScala.mapValues(v => normalPartitionDataSet(v.fetchOffset.toInt)).toSeq
       }
     }
 
-    override protected def buildFetchRequest(partitionMap: collection.Seq[(TopicPartition, PartitionFetchState)]): ResultWithPartitions[DummyFetchRequest] = {
+    override protected def buildFetch(partitionMap: collection.Map[TopicPartition, PartitionFetchState]): ResultWithPartitions[Option[FetchRequest.Builder]] = {
       val requestMap = new mutable.HashMap[TopicPartition, Long]
       partitionMap.foreach { case (topicPartition, partitionFetchState) =>
         // Add backoff delay check
         if (partitionFetchState.isReadyForFetch)
           requestMap.put(topicPartition, partitionFetchState.fetchOffset)
       }
-      ResultWithPartitions(new DummyFetchRequest(requestMap), Set())
+      ResultWithPartitions(Some(fetchRequestBuilder(partitionMap)), Set())
     }
 
-    override def handlePartitionsWithErrors(partitions: Iterable[TopicPartition]) = delayPartitions(partitions, fetchBackOffMs.toLong)
-
   }
 
 }
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaAlterLogDirsThreadTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaAlterLogDirsThreadTest.scala
index 29a1c9f..8fb5ab6 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaAlterLogDirsThreadTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaAlterLogDirsThreadTest.scala
@@ -479,12 +479,15 @@ class ReplicaAlterLogDirsThreadTest {
       brokerTopicStats = null)
     thread.addPartitions(Map(t1p0 -> 0, t1p1 -> 0))
 
-    val ResultWithPartitions(fetchRequest, partitionsWithError) =
-      thread.buildFetchRequest(Seq((t1p0, new PartitionFetchState(150)), (t1p1, new PartitionFetchState(160))))
+    val ResultWithPartitions(fetchRequestOpt, partitionsWithError) = thread.buildFetch(Map(
+      t1p0 -> new PartitionFetchState(150),
+      t1p1 -> new PartitionFetchState(160)))
 
-    assertFalse(fetchRequest.isEmpty)
+    assertTrue(fetchRequestOpt.isDefined)
+    val fetchRequest = fetchRequestOpt.get
+    assertFalse(fetchRequest.fetchData.isEmpty)
     assertFalse(partitionsWithError.nonEmpty)
-    val request = fetchRequest.underlying.build()
+    val request = fetchRequest.build()
     assertEquals(0, request.minBytes)
     val fetchInfos = request.fetchData.asScala.toSeq
     assertEquals(1, fetchInfos.length)
@@ -523,37 +526,38 @@ class ReplicaAlterLogDirsThreadTest {
     thread.addPartitions(Map(t1p0 -> 0, t1p1 -> 0))
 
     // one partition is ready and one is truncating
-    val ResultWithPartitions(fetchRequest, partitionsWithError) =
-      thread.buildFetchRequest(Seq(
-        (t1p0, new PartitionFetchState(150)),
-        (t1p1, new PartitionFetchState(160, truncatingLog=true))))
+    val ResultWithPartitions(fetchRequestOpt, partitionsWithError) = thread.buildFetch(Map(
+        t1p0 -> new PartitionFetchState(150),
+        t1p1 -> new PartitionFetchState(160, truncatingLog=true)))
 
-    assertFalse(fetchRequest.isEmpty)
+    assertTrue(fetchRequestOpt.isDefined)
+    val fetchRequest = fetchRequestOpt.get
+    assertFalse(fetchRequest.fetchData.isEmpty)
     assertFalse(partitionsWithError.nonEmpty)
-    val fetchInfos = fetchRequest.underlying.build().fetchData.asScala.toSeq
+    val fetchInfos = fetchRequest.build().fetchData.asScala.toSeq
     assertEquals(1, fetchInfos.length)
     assertEquals("Expected fetch request for non-truncating partition", t1p0, fetchInfos.head._1)
     assertEquals(150, fetchInfos.head._2.fetchOffset)
 
     // one partition is ready and one is delayed
-    val ResultWithPartitions(fetchRequest2, partitionsWithError2) =
-      thread.buildFetchRequest(Seq(
-        (t1p0, new PartitionFetchState(140)),
-        (t1p1, new PartitionFetchState(160, delay=new DelayedItem(5000)))))
+    val ResultWithPartitions(fetchRequest2Opt, partitionsWithError2) = thread.buildFetch(Map(
+        t1p0 -> new PartitionFetchState(140),
+        t1p1 -> new PartitionFetchState(160, delay=new DelayedItem(5000))))
 
-    assertFalse(fetchRequest2.isEmpty)
+    assertTrue(fetchRequest2Opt.isDefined)
+    val fetchRequest2 = fetchRequest2Opt.get
+    assertFalse(fetchRequest2.fetchData.isEmpty)
     assertFalse(partitionsWithError2.nonEmpty)
-    val fetchInfos2 = fetchRequest2.underlying.build().fetchData.asScala.toSeq
+    val fetchInfos2 = fetchRequest2.build().fetchData.asScala.toSeq
     assertEquals(1, fetchInfos2.length)
     assertEquals("Expected fetch request for non-delayed partition", t1p0, fetchInfos2.head._1)
     assertEquals(140, fetchInfos2.head._2.fetchOffset)
 
     // both partitions are delayed
-    val ResultWithPartitions(fetchRequest3, partitionsWithError3) =
-      thread.buildFetchRequest(Seq(
-        (t1p0, new PartitionFetchState(140, delay=new DelayedItem(5000))),
-        (t1p1, new PartitionFetchState(160, delay=new DelayedItem(5000)))))
-    assertTrue("Expected no fetch requests since all partitions are delayed", fetchRequest3.isEmpty)
+    val ResultWithPartitions(fetchRequest3Opt, partitionsWithError3) = thread.buildFetch(Map(
+        t1p0 -> new PartitionFetchState(140, delay=new DelayedItem(5000)),
+        t1p1 -> new PartitionFetchState(160, delay=new DelayedItem(5000))))
+    assertTrue("Expected no fetch requests since all partitions are delayed", fetchRequest3Opt.isEmpty)
     assertFalse(partitionsWithError3.nonEmpty)
   }
 
