This is an automated email from the ASF dual-hosted git repository.

mimaison pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 74ebbae8ece KAFKA-19132: Move FetchSession and related classes to server module (#20158)
74ebbae8ece is described below

commit 74ebbae8ece464573c1288e8f233ef804074fe7b
Author: Dmitry Werner <[email protected]>
AuthorDate: Thu Jan 29 16:53:34 2026 +0500

    KAFKA-19132: Move FetchSession and related classes to server module (#20158)
    
    
    Reviewers: Mickael Maison <[email protected]>
---
 .../kafka/server/builders/KafkaApisBuilder.java    |   2 +-
 .../src/main/scala/kafka/server/BrokerServer.scala |   9 +-
 .../src/main/scala/kafka/server/FetchSession.scala | 911 ---------------------
 core/src/main/scala/kafka/server/KafkaApis.scala   |   8 +-
 .../scala/unit/kafka/server/FetchSessionTest.scala | 136 +--
 .../scala/unit/kafka/server/KafkaApisTest.scala    |  24 +-
 .../metadata/KRaftMetadataRequestBenchmark.java    |   2 +-
 .../java/org/apache/kafka/server/FetchContext.java | 385 +++++++++
 .../java/org/apache/kafka/server/FetchManager.java | 136 +++
 .../java/org/apache/kafka/server/FetchSession.java | 493 +++++++++++
 .../kafka/server/FetchSessionCacheShard.java       | 297 +++++++
 11 files changed, 1402 insertions(+), 1001 deletions(-)

diff --git a/core/src/main/java/kafka/server/builders/KafkaApisBuilder.java b/core/src/main/java/kafka/server/builders/KafkaApisBuilder.java
index e03ab35e90e..3f960e54779 100644
--- a/core/src/main/java/kafka/server/builders/KafkaApisBuilder.java
+++ b/core/src/main/java/kafka/server/builders/KafkaApisBuilder.java
@@ -20,7 +20,6 @@ package kafka.server.builders;
 import kafka.coordinator.transaction.TransactionCoordinator;
 import kafka.network.RequestChannel;
 import kafka.server.AutoTopicCreationManager;
-import kafka.server.FetchManager;
 import kafka.server.ForwardingManager;
 import kafka.server.KafkaApis;
 import kafka.server.KafkaConfig;
@@ -39,6 +38,7 @@ import org.apache.kafka.metadata.MetadataCache;
 import org.apache.kafka.security.DelegationTokenManager;
 import org.apache.kafka.server.ApiVersionManager;
 import org.apache.kafka.server.ClientMetricsManager;
+import org.apache.kafka.server.FetchManager;
 import org.apache.kafka.server.authorizer.Authorizer;
 import org.apache.kafka.storage.log.metrics.BrokerTopicStats;
 
diff --git a/core/src/main/scala/kafka/server/BrokerServer.scala b/core/src/main/scala/kafka/server/BrokerServer.scala
index 624d7a6ed89..7a878a6352b 100644
--- a/core/src/main/scala/kafka/server/BrokerServer.scala
+++ b/core/src/main/scala/kafka/server/BrokerServer.scala
@@ -43,6 +43,7 @@ import org.apache.kafka.image.publisher.{BrokerRegistrationTracker, MetadataPubl
 import org.apache.kafka.metadata.{BrokerState, KRaftMetadataCache, ListenerInfo, MetadataCache, MetadataVersionConfigValidator}
 import org.apache.kafka.metadata.publisher.{AclPublisher, DelegationTokenPublisher, DynamicClientQuotaPublisher, DynamicTopicClusterQuotaPublisher, ScramPublisher}
 import org.apache.kafka.security.{CredentialProvider, DelegationTokenManager}
+import org.apache.kafka.server.FetchSession.FetchSessionCache
 import org.apache.kafka.server.authorizer.Authorizer
 import org.apache.kafka.server.common.{ApiMessageAndVersion, DirectoryEventHandler, NodeToControllerChannelManager, TopicIdPartition}
 import org.apache.kafka.server.config.{ConfigType, DelegationTokenManagerConfigs}
@@ -54,7 +55,7 @@ import org.apache.kafka.server.share.persister.{DefaultStatePersister, NoOpState
 import org.apache.kafka.server.share.session.ShareSessionCache
 import org.apache.kafka.server.util.timer.{SystemTimer, SystemTimerReaper}
 import org.apache.kafka.server.util.{Deadline, FutureUtils, KafkaScheduler}
-import org.apache.kafka.server.{AssignmentsManager, BrokerFeatures, BrokerLifecycleManager, ClientMetricsManager, DefaultApiVersionManager, DelayedActionQueue, KRaftTopicCreator, NodeToControllerChannelManagerImpl, ProcessRole, RaftControllerNodeProvider}
+import org.apache.kafka.server.{AssignmentsManager, BrokerFeatures, BrokerLifecycleManager, ClientMetricsManager, DefaultApiVersionManager, DelayedActionQueue, FetchManager, FetchSessionCacheShard, KRaftTopicCreator, NodeToControllerChannelManagerImpl, ProcessRole, RaftControllerNodeProvider}
 import org.apache.kafka.server.transaction.AddPartitionsToTxnManager
 import org.apache.kafka.storage.internals.log.LogDirFailureChannel
 import org.apache.kafka.storage.log.metrics.BrokerTopicStats
@@ -422,13 +423,15 @@ class BrokerServer(
       // The FetchSessionCache is divided into config.numIoThreads shards, each responsible
       // for Math.max(1, shardNum * sessionIdRange) <= sessionId < (shardNum + 1) * sessionIdRange
       val sessionIdRange = Int.MaxValue / NumFetchSessionCacheShards
-      val fetchSessionCacheShards = (0 until NumFetchSessionCacheShards)
-        .map(shardNum => new FetchSessionCacheShard(
+      val fetchSessionCacheShards: util.List[FetchSessionCacheShard] = new util.ArrayList()
+      for (shardNum <- 0 until NumFetchSessionCacheShards) {
+        fetchSessionCacheShards.add(new FetchSessionCacheShard(
          config.maxIncrementalFetchSessionCacheSlots / NumFetchSessionCacheShards,
           KafkaBroker.MIN_INCREMENTAL_FETCH_SESSION_EVICTION_MS,
           sessionIdRange,
           shardNum
         ))
+      }
       val fetchManager = new FetchManager(Time.SYSTEM, new FetchSessionCache(fetchSessionCacheShards))
 
       sharePartitionManager = new SharePartitionManager(
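
To make the shard arithmetic in the comment above concrete: each shard owns a contiguous, non-overlapping slice of the positive session ID space. A minimal Scala sketch, assuming a hypothetical shard count of 8 (the real count comes from config.numIoThreads):

    val numShards = 8 // assumed value, for illustration only
    val sessionIdRange = Int.MaxValue / numShards
    for (shardNum <- 0 until numShards) {
      val lo = Math.max(1, shardNum * sessionIdRange) // shard 0 skips INVALID_SESSION_ID (0)
      val hi = (shardNum + 1) * sessionIdRange
      println(s"shard $shardNum owns session IDs in [$lo, $hi)")
    }
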
diff --git a/core/src/main/scala/kafka/server/FetchSession.scala b/core/src/main/scala/kafka/server/FetchSession.scala
deleted file mode 100644
index 4bbb4c47e3f..00000000000
--- a/core/src/main/scala/kafka/server/FetchSession.scala
+++ /dev/null
@@ -1,911 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package kafka.server
-
-import com.typesafe.scalalogging.Logger
-import kafka.utils.Logging
-import org.apache.kafka.common.{Node, TopicIdPartition, TopicPartition, Uuid}
-import org.apache.kafka.common.message.FetchResponseData
-import org.apache.kafka.common.protocol.Errors
-import org.apache.kafka.common.requests.FetchMetadata.{FINAL_EPOCH, INITIAL_EPOCH, INVALID_SESSION_ID}
-import org.apache.kafka.common.requests.{FetchRequest, FetchResponse, FetchMetadata => JFetchMetadata}
-import org.apache.kafka.common.utils.{ImplicitLinkedHashCollection, Time, Utils}
-import org.apache.kafka.server.metrics.KafkaMetricsGroup
-
-import java.util
-import java.util.Optional
-import java.util.concurrent.atomic.AtomicInteger
-import java.util.concurrent.{ThreadLocalRandom, TimeUnit}
-import scala.collection.mutable
-import scala.math.Ordered.orderingToOrdered
-
-object FetchSession {
-  type REQ_MAP = util.Map[TopicIdPartition, FetchRequest.PartitionData]
-  type RESP_MAP = util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
-  type CACHE_MAP = ImplicitLinkedHashCollection[CachedPartition]
-  type RESP_MAP_ITER = util.Iterator[util.Map.Entry[TopicIdPartition, FetchResponseData.PartitionData]]
-  type TOPIC_NAME_MAP = util.Map[Uuid, String]
-
-  val NUM_INCREMENTAL_FETCH_SESSIONS = "NumIncrementalFetchSessions"
-  val NUM_INCREMENTAL_FETCH_PARTITIONS_CACHED = "NumIncrementalFetchPartitionsCached"
-  val INCREMENTAL_FETCH_SESSIONS_EVICTIONS_PER_SEC = "IncrementalFetchSessionEvictionsPerSec"
-  val EVICTIONS = "evictions"
-
-  def partitionsToLogString(partitions: util.Collection[TopicIdPartition], traceEnabled: Boolean): String = {
-    if (traceEnabled) {
-      "(" + String.join(", ", partitions.toString) + ")"
-    } else {
-      s"${partitions.size} partition(s)"
-    }
-  }
-}
-
-/**
-  * A cached partition.
-  *
-  * The broker maintains a set of these objects for each incremental fetch session.
-  * When an incremental fetch request is made, any partitions which are not explicitly
-  * enumerated in the fetch request are loaded from the cache.  Similarly, when an
-  * incremental fetch response is being prepared, any partitions that have not changed and
-  * do not have errors are left out of the response.
-  *
-  * We store many of these objects, so it is important for them to be memory-efficient.
-  * That is why we store topic and partition separately rather than storing a TopicPartition
-  * object.  The TP object takes up more memory because it is a separate JVM object, and
-  * because it stores the cached hash code in memory.
-  *
-  * Note that fetcherLogStartOffset is the LSO of the follower performing the fetch, whereas
-  * localLogStartOffset is the log start offset of the partition on this broker.
-  */
-class CachedPartition(var topic: String,
-                      val topicId: Uuid,
-                      val partition: Int,
-                      var maxBytes: Int,
-                      var fetchOffset: Long,
-                      var highWatermark: Long,
-                      var leaderEpoch: Optional[Integer],
-                      var fetcherLogStartOffset: Long,
-                      var localLogStartOffset: Long,
-                      var lastFetchedEpoch: Optional[Integer])
-    extends ImplicitLinkedHashCollection.Element {
-
-  private var cachedNext: Int = ImplicitLinkedHashCollection.INVALID_INDEX
-  private var cachedPrev: Int = ImplicitLinkedHashCollection.INVALID_INDEX
-
-  override def next: Int = cachedNext
-  override def setNext(next: Int): Unit = this.cachedNext = next
-  override def prev: Int = cachedPrev
-  override def setPrev(prev: Int): Unit = this.cachedPrev = prev
-
-  def this(topic: String, topicId: Uuid, partition: Int) =
-    this(topic, topicId, partition, -1, -1, -1, Optional.empty(), -1, -1, Optional.empty[Integer])
-
-  def this(part: TopicIdPartition) = {
-    this(part.topic, part.topicId, part.partition)
-  }
-
-  def this(part: TopicIdPartition, reqData: FetchRequest.PartitionData) =
-    this(part.topic, part.topicId, part.partition, reqData.maxBytes, reqData.fetchOffset, -1,
-      reqData.currentLeaderEpoch, reqData.logStartOffset, -1, reqData.lastFetchedEpoch)
-
-  def this(part: TopicIdPartition, reqData: FetchRequest.PartitionData,
-           respData: FetchResponseData.PartitionData) =
-    this(part.topic, part.topicId, part.partition, reqData.maxBytes, reqData.fetchOffset, respData.highWatermark,
-      reqData.currentLeaderEpoch, reqData.logStartOffset, respData.logStartOffset, reqData.lastFetchedEpoch)
-
-  def reqData = new FetchRequest.PartitionData(topicId, fetchOffset, fetcherLogStartOffset, maxBytes, leaderEpoch, lastFetchedEpoch)
-
-  def updateRequestParams(reqData: FetchRequest.PartitionData): Unit = {
-    // Update our cached request parameters.
-    maxBytes = reqData.maxBytes
-    fetchOffset = reqData.fetchOffset
-    fetcherLogStartOffset = reqData.logStartOffset
-    leaderEpoch = reqData.currentLeaderEpoch
-    lastFetchedEpoch = reqData.lastFetchedEpoch
-  }
-
-  def maybeResolveUnknownName(topicNames: FetchSession.TOPIC_NAME_MAP): Unit = {
-    if (this.topic == null) {
-      this.topic = topicNames.get(this.topicId)
-    }
-  }
-
-  /**
-    * Determine whether or not the specified cached partition should be included in the FetchResponse we send back to
-    * the fetcher and update it if requested.
-    *
-    * This function should be called while holding the appropriate session lock.
-    *
-    * @param respData partition data
-    * @param updateResponseData if set to true, update this CachedPartition with new request and response data.
-    * @return True if this partition should be included in the response; false if it can be omitted.
-    */
-  def maybeUpdateResponseData(respData: FetchResponseData.PartitionData, updateResponseData: Boolean): Boolean = {
-    // Check the response data.
-    var mustRespond = false
-    if (FetchResponse.recordsSize(respData) > 0) {
-      // Partitions with new data are always included in the response.
-      mustRespond = true
-    }
-    if (highWatermark != respData.highWatermark) {
-      mustRespond = true
-      if (updateResponseData)
-        highWatermark = respData.highWatermark
-    }
-    if (localLogStartOffset != respData.logStartOffset) {
-      mustRespond = true
-      if (updateResponseData)
-        localLogStartOffset = respData.logStartOffset
-    }
-    if (FetchResponse.isPreferredReplica(respData)) {
-      // If the broker computed a preferred read replica, we need to include it in the response
-      mustRespond = true
-    }
-    if (respData.errorCode != Errors.NONE.code) {
-      // Partitions with errors are always included in the response.
-      // We also set the cached highWatermark to an invalid offset, -1.
-      // This ensures that when the error goes away, we re-send the partition.
-      if (updateResponseData)
-        highWatermark = -1
-      mustRespond = true
-    }
-
-    if (FetchResponse.isDivergingEpoch(respData)) {
-      // Partitions with diverging epoch are always included in response to trigger truncation.
-      mustRespond = true
-    }
-    mustRespond
-  }
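
A sketch of the contract documented above (illustrative only, not part of the patch): a partition whose records, watermarks, and error code are all unchanged is omitted from an incremental response, while any change forces inclusion.

    val part = new CachedPartition("foo", Uuid.randomUuid(), 0) // highWatermark starts at -1
    val unchanged = new FetchResponseData.PartitionData()
      .setPartitionIndex(0).setHighWatermark(-1).setLogStartOffset(-1)
    assert(!part.maybeUpdateResponseData(unchanged, true))      // nothing changed -> omit
    val advanced = new FetchResponseData.PartitionData()
      .setPartitionIndex(0).setHighWatermark(100).setLogStartOffset(-1)
    assert(part.maybeUpdateResponseData(advanced, true))        // high watermark moved -> respond
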
-
-  /**
-   * We have different equality checks depending on whether topic IDs are used.
-   * This means we need a different hash function as well. We use name to calculate the hash if the ID is zero and unused.
-   * Otherwise, we use the topic ID in the hash calculation.
-   *
-   * @return the hash code for the CachedPartition depending on what request version we are using.
-   */
-  override def hashCode: Int =
-    if (topicId != Uuid.ZERO_UUID)
-      (31 * partition) + topicId.hashCode
-    else
-      (31 * partition) + topic.hashCode
-
-  /**
-   * We have different equality checks depending on whether topic IDs are used.
-   *
-   * This is because when we use topic IDs, a partition with a given ID and an unknown name is the same as a partition with that
-   * ID and a known name. This means we can only use topic ID and partition when determining equality.
-   *
-   * On the other hand, if we are using topic names, all IDs are zero. This means we can only use topic name and partition
-   * when determining equality.
-   */
-  override def equals(that: Any): Boolean =
-    that match {
-      case that: CachedPartition =>
-        this.eq(that) || (if (this.topicId != Uuid.ZERO_UUID)
-          this.partition.equals(that.partition) && this.topicId.equals(that.topicId)
-        else
-          this.partition.equals(that.partition) && this.topic.equals(that.topic))
-      case _ => false
-    }
-
-  override def toString: String = synchronized {
-    "CachedPartition(topic=" + topic +
-      ", topicId=" + topicId +
-      ", partition=" + partition +
-      ", maxBytes=" + maxBytes +
-      ", fetchOffset=" + fetchOffset +
-      ", highWatermark=" + highWatermark +
-      ", fetcherLogStartOffset=" + fetcherLogStartOffset +
-      ", localLogStartOffset=" + localLogStartOffset  +
-        ")"
-  }
-}
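
The ID-based equality described above can be checked directly; a minimal sketch (illustrative only, not part of the patch):

    val id = Uuid.randomUuid()
    val resolved = new CachedPartition("foo", id, 0)
    val unresolved = new CachedPartition(null, id, 0)  // topic name not yet resolved
    assert(resolved == unresolved)                     // non-zero ID: compare (topicId, partition)
    assert(resolved.hashCode == unresolved.hashCode)   // hash also keyed on the topic ID
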
-
-/**
-  * The fetch session.
-  *
-  * Each fetch session is protected by its own lock, which must be taken before mutable
-  * fields are read or modified.  This includes modification of the session partition map.
-  *
-  * @param id                 The unique fetch session ID.
-  * @param privileged         True if this session is privileged.  Sessions created by followers
-  *                           are privileged; sessions created by consumers are not.
-  * @param partitionMap       The CachedPartitionMap.
-  * @param usesTopicIds       True if this session is using topic IDs
-  * @param creationMs         The time in milliseconds when this session was created.
-  * @param lastUsedMs         The last used time in milliseconds.  This should only be updated by
-  *                           FetchSessionCache#touch.
-  * @param epoch              The fetch session sequence number.
-  */
-class FetchSession(val id: Int,
-                   val privileged: Boolean,
-                   val partitionMap: FetchSession.CACHE_MAP,
-                   val usesTopicIds: Boolean,
-                   val creationMs: Long,
-                   var lastUsedMs: Long,
-                   var epoch: Int) {
-  // This is used by the FetchSessionCache to store the last known size of this session.
-  // If this is -1, the Session is not in the cache.
-  var cachedSize = -1
-
-  def size: Int = synchronized {
-    partitionMap.size
-  }
-
-  def isEmpty: Boolean = synchronized {
-    partitionMap.isEmpty
-  }
-
-  def lastUsedKey: LastUsedKey = synchronized {
-    LastUsedKey(lastUsedMs, id)
-  }
-
-  def evictableKey: EvictableKey = synchronized {
-    EvictableKey(privileged, cachedSize, id)
-  }
-
-  def metadata: JFetchMetadata = synchronized { new JFetchMetadata(id, epoch) }
-
-  def getFetchOffset(topicIdPartition: TopicIdPartition): Option[Long] = synchronized {
-    Option(partitionMap.find(new CachedPartition(topicIdPartition))).map(_.fetchOffset)
-  }
-
-  private type TL = util.ArrayList[TopicIdPartition]
-
-  // Update the cached partition data based on the request.
-  def update(fetchData: FetchSession.REQ_MAP,
-             toForget: util.List[TopicIdPartition]): (TL, TL, TL) = synchronized {
-    val added = new TL
-    val updated = new TL
-    val removed = new TL
-    fetchData.forEach { (topicPart, reqData) =>
-      val cachedPartitionKey = new CachedPartition(topicPart, reqData)
-      val cachedPart = partitionMap.find(cachedPartitionKey)
-      if (cachedPart == null) {
-        partitionMap.mustAdd(cachedPartitionKey)
-        added.add(topicPart)
-      } else {
-        cachedPart.updateRequestParams(reqData)
-        updated.add(topicPart)
-      }
-    }
-    toForget.forEach { p =>
-      if (partitionMap.remove(new CachedPartition(p))) {
-        removed.add(p)
-      }
-    }
-    (added, updated, removed)
-  }
-
-  override def toString: String = synchronized {
-    "FetchSession(id=" + id +
-      ", privileged=" + privileged +
-      ", partitionMap.size=" + partitionMap.size +
-      ", usesTopicIds=" + usesTopicIds +
-      ", creationMs=" + creationMs +
-      ", lastUsedMs=" + lastUsedMs +
-      ", epoch=" + epoch + ")"
-  }
-}
-
-trait FetchContext extends Logging {
-  /**
-    * Get the fetch offset for a given partition.
-    */
-  def getFetchOffset(part: TopicIdPartition): Option[Long]
-
-  /**
-    * Apply a function to each partition in the fetch request.
-    */
-  def foreachPartition(fun: (TopicIdPartition, FetchRequest.PartitionData) => Unit): Unit
-
-  /**
-    * Get the response size to be used for quota computation. Since we are returning an empty response in case of
-    * throttling, we are not supposed to update the context until we know that we are not going to throttle.
-    */
-  def getResponseSize(updates: FetchSession.RESP_MAP, versionId: Short): Int
-
-  /**
-    * Updates the fetch context with new partition information.  Generates response data.
-    * The response data may require subsequent down-conversion.
-    */
-  def updateAndGenerateResponseData(updates: FetchSession.RESP_MAP, nodeEndpoints: util.List[Node]): FetchResponse
-
-  def partitionsToLogString(partitions: util.Collection[TopicIdPartition]): String =
-    FetchSession.partitionsToLogString(partitions, isTraceEnabled)
-
-  /**
-    * Return an empty throttled response due to quota violation.
-    */
-  def getThrottledResponse(throttleTimeMs: Int, nodeEndpoints: util.List[Node]): FetchResponse =
-    FetchResponse.of(Errors.NONE, throttleTimeMs, INVALID_SESSION_ID, new FetchSession.RESP_MAP, nodeEndpoints)
-}
-
-/**
-  * The fetch context for a fetch request that had a session error.
-  */
-class SessionErrorContext(val error: Errors,
-                          val reqMetadata: JFetchMetadata) extends FetchContext {
-  override def getFetchOffset(part: TopicIdPartition): Option[Long] = None
-
-  override def foreachPartition(fun: (TopicIdPartition, FetchRequest.PartitionData) => Unit): Unit = {}
-
-  override def getResponseSize(updates: FetchSession.RESP_MAP, versionId: Short): Int = {
-    FetchResponse.sizeOf(versionId, (new FetchSession.RESP_MAP).entrySet.iterator)
-  }
-
-  // Because of the fetch session error, we don't know what partitions were supposed to be in this request.
-  override def updateAndGenerateResponseData(updates: FetchSession.RESP_MAP, nodeEndpoints: util.List[Node]): FetchResponse = {
-    debug(s"Session error fetch context returning $error")
-    FetchResponse.of(error, 0, INVALID_SESSION_ID, new FetchSession.RESP_MAP, nodeEndpoints)
-  }
-}
-
-object SessionlessFetchContext {
-  private final val logger = Logger(classOf[SessionlessFetchContext])
-}
-
-/**
-  * The fetch context for a sessionless fetch request.
-  *
-  * @param fetchData          The partition data from the fetch request.
-  */
-class SessionlessFetchContext(val fetchData: util.Map[TopicIdPartition, FetchRequest.PartitionData]) extends FetchContext {
-
-  override lazy val logger = SessionlessFetchContext.logger
-
-  override def getFetchOffset(part: TopicIdPartition): Option[Long] =
-    Option(fetchData.get(part)).map(_.fetchOffset)
-
-  override def foreachPartition(fun: (TopicIdPartition, FetchRequest.PartitionData) => Unit): Unit = {
-    fetchData.forEach((tp, data) => fun(tp, data))
-  }
-
-  override def getResponseSize(updates: FetchSession.RESP_MAP, versionId: Short): Int = {
-    FetchResponse.sizeOf(versionId, updates.entrySet.iterator)
-  }
-
-  override def updateAndGenerateResponseData(updates: FetchSession.RESP_MAP, nodeEndpoints: util.List[Node]): FetchResponse = {
-    debug(s"Sessionless fetch context returning ${partitionsToLogString(updates.keySet)}")
-    FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, updates, nodeEndpoints)
-  }
-}
-
-object FullFetchContext {
-  private final val logger = Logger(classOf[FullFetchContext])
-}
-
-/**
-  * The fetch context for a full fetch request.
-  *
-  * @param time               The clock to use.
-  * @param cache              The fetch session cache.
-  * @param reqMetadata        The request metadata.
-  * @param fetchData          The partition data from the fetch request.
-  * @param usesTopicIds       True if this session should use topic IDs.
-  * @param isFromFollower     True if this fetch request came from a follower.
-  */
-class FullFetchContext(private val time: Time,
-                       private val cache: FetchSessionCache,
-                       private val reqMetadata: JFetchMetadata,
-                       private val fetchData: util.Map[TopicIdPartition, FetchRequest.PartitionData],
-                       private val usesTopicIds: Boolean,
-                       private val isFromFollower: Boolean) extends FetchContext {
-
-  def this(time: Time,
-           cacheShard: FetchSessionCacheShard,
-           reqMetadata: JFetchMetadata,
-           fetchData: util.Map[TopicIdPartition, FetchRequest.PartitionData],
-           usesTopicIds: Boolean,
-           isFromFollower: Boolean
-          ) = this(time, new FetchSessionCache(Seq(cacheShard)), reqMetadata, fetchData, usesTopicIds, isFromFollower)
-
-  override lazy val logger = FullFetchContext.logger
-
-  override def getFetchOffset(part: TopicIdPartition): Option[Long] =
-    Option(fetchData.get(part)).map(_.fetchOffset)
-
-  override def foreachPartition(fun: (TopicIdPartition, FetchRequest.PartitionData) => Unit): Unit = {
-    fetchData.forEach((tp, data) => fun(tp, data))
-  }
-
-  override def getResponseSize(updates: FetchSession.RESP_MAP, versionId: Short): Int = {
-    FetchResponse.sizeOf(versionId, updates.entrySet.iterator)
-  }
-
-  override def updateAndGenerateResponseData(updates: FetchSession.RESP_MAP, nodeEndpoints: util.List[Node]): FetchResponse = {
-    def createNewSession: FetchSession.CACHE_MAP = {
-      val cachedPartitions = new FetchSession.CACHE_MAP(updates.size)
-      updates.forEach { (part, respData) =>
-        val reqData = fetchData.get(part)
-        cachedPartitions.mustAdd(new CachedPartition(part, reqData, respData))
-      }
-      cachedPartitions
-    }
-    val cacheShard = cache.getNextCacheShard
-    val responseSessionId = cacheShard.maybeCreateSession(time.milliseconds(), isFromFollower,
-        updates.size, usesTopicIds, () => createNewSession)
-    debug(s"Full fetch context with session id $responseSessionId returning " +
-      s"${partitionsToLogString(updates.keySet)}")
-    FetchResponse.of(Errors.NONE, 0, responseSessionId, updates, nodeEndpoints)
-  }
-}
-
-object IncrementalFetchContext {
-  private val logger = Logger(classOf[IncrementalFetchContext])
-}
-
-/**
-  * The fetch context for an incremental fetch request.
-  *
-  * @param time         The clock to use.
-  * @param reqMetadata  The request metadata.
-  * @param session      The incremental fetch request session.
-  * @param topicNames   A mapping from topic ID to topic name used to resolve partitions already in the session.
-  */
-class IncrementalFetchContext(private val time: Time,
-                              private val reqMetadata: JFetchMetadata,
-                              private val session: FetchSession,
-                              private val topicNames: FetchSession.TOPIC_NAME_MAP) extends FetchContext {
-
-  override lazy val logger = IncrementalFetchContext.logger
-
-  override def getFetchOffset(tp: TopicIdPartition): Option[Long] = session.getFetchOffset(tp)
-
-  override def foreachPartition(fun: (TopicIdPartition, FetchRequest.PartitionData) => Unit): Unit = {
-    // Take the session lock and iterate over all the cached partitions.
-    session.synchronized {
-      session.partitionMap.forEach { part =>
-        // Try to resolve an unresolved partition if it does not yet have a name
-        if (session.usesTopicIds)
-          part.maybeResolveUnknownName(topicNames)
-        fun(new TopicIdPartition(part.topicId, new TopicPartition(part.topic, part.partition)), part.reqData)
-      }
-    }
-  }
-
-  // Iterator that goes over the given partition map and selects partitions that need to be included in the response.
-  // If updateFetchContextAndRemoveUnselected is set to true, the fetch context will be updated for the selected
-  // partitions and also remove unselected ones as they are encountered.
-  private class PartitionIterator(val iter: FetchSession.RESP_MAP_ITER,
-                                  val updateFetchContextAndRemoveUnselected: Boolean)
-    extends FetchSession.RESP_MAP_ITER {
-    private var nextElement: util.Map.Entry[TopicIdPartition, FetchResponseData.PartitionData] = _
-
-    override def hasNext: Boolean = {
-      while ((nextElement == null) && iter.hasNext) {
-        val element = iter.next()
-        val topicPart = element.getKey
-        val respData = element.getValue
-        val cachedPart = session.partitionMap.find(new CachedPartition(topicPart))
-        val mustRespond = cachedPart.maybeUpdateResponseData(respData, updateFetchContextAndRemoveUnselected)
-        if (mustRespond) {
-          nextElement = element
-          if (updateFetchContextAndRemoveUnselected && FetchResponse.recordsSize(respData) > 0) {
-            session.partitionMap.remove(cachedPart)
-            session.partitionMap.mustAdd(cachedPart)
-          }
-        } else {
-          if (updateFetchContextAndRemoveUnselected) {
-            iter.remove()
-          }
-        }
-      }
-      nextElement != null
-    }
-
-    override def next(): util.Map.Entry[TopicIdPartition, FetchResponseData.PartitionData] = {
-      if (!hasNext) throw new NoSuchElementException
-      val element = nextElement
-      nextElement = null
-      element
-    }
-
-    override def remove(): Unit = throw new UnsupportedOperationException
-  }
-
-  override def getResponseSize(updates: FetchSession.RESP_MAP, versionId: Short): Int = {
-    session.synchronized {
-      val expectedEpoch = JFetchMetadata.nextEpoch(reqMetadata.epoch)
-      if (session.epoch != expectedEpoch) {
-        FetchResponse.sizeOf(versionId, (new FetchSession.RESP_MAP).entrySet.iterator)
-      } else {
-        // Pass the partition iterator which updates neither the fetch context nor the partition map.
-        FetchResponse.sizeOf(versionId, new PartitionIterator(updates.entrySet.iterator, false))
-      }
-    }
-  }
-
-  override def updateAndGenerateResponseData(updates: FetchSession.RESP_MAP, nodeEndpoints: util.List[Node]): FetchResponse = {
-    session.synchronized {
-      // Check to make sure that the session epoch didn't change in between
-      // creating this fetch context and generating this response.
-      val expectedEpoch = JFetchMetadata.nextEpoch(reqMetadata.epoch)
-      if (session.epoch != expectedEpoch) {
-        info(s"Incremental fetch session ${session.id} expected epoch 
$expectedEpoch, but " +
-          s"got ${session.epoch}.  Possible duplicate request.")
-        FetchResponse.of(Errors.INVALID_FETCH_SESSION_EPOCH, 0, session.id, 
new FetchSession.RESP_MAP, nodeEndpoints)
-      } else {
-        // Iterate over the update list using PartitionIterator. This will 
prune updates which don't need to be sent
-        val partitionIter = new PartitionIterator(updates.entrySet.iterator, 
true)
-        while (partitionIter.hasNext) {
-          partitionIter.next()
-        }
-        debug(s"Incremental fetch context with session id ${session.id} 
returning " +
-          s"${partitionsToLogString(updates.keySet)}")
-        FetchResponse.of(Errors.NONE, 0, session.id, updates, nodeEndpoints)
-      }
-    }
-  }
-
-  override def getThrottledResponse(throttleTimeMs: Int, nodeEndpoints: util.List[Node]): FetchResponse = {
-    session.synchronized {
-      // Check to make sure that the session epoch didn't change in between
-      // creating this fetch context and generating this response.
-      val expectedEpoch = JFetchMetadata.nextEpoch(reqMetadata.epoch)
-      if (session.epoch != expectedEpoch) {
-        info(s"Incremental fetch session ${session.id} expected epoch 
$expectedEpoch, but " +
-          s"got ${session.epoch}.  Possible duplicate request.")
-        FetchResponse.of(Errors.INVALID_FETCH_SESSION_EPOCH, throttleTimeMs, 
session.id, new FetchSession.RESP_MAP, nodeEndpoints)
-      } else {
-        FetchResponse.of(Errors.NONE, throttleTimeMs, session.id, new 
FetchSession.RESP_MAP, nodeEndpoints)
-      }
-    }
-  }
-}
-
-case class LastUsedKey(lastUsedMs: Long, id: Int) extends Comparable[LastUsedKey] {
-  override def compareTo(other: LastUsedKey): Int =
-    (lastUsedMs, id) compare (other.lastUsedMs, other.id)
-}
-
-case class EvictableKey(privileged: Boolean, size: Int, id: Int) extends Comparable[EvictableKey] {
-  override def compareTo(other: EvictableKey): Int =
-    (privileged, size, id) compare (other.privileged, other.size, other.id)
-}
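
These keys order lexicographically over their tuples, so TreeMap#firstEntry yields the least valuable entry: unprivileged before privileged, then fewer cached partitions first. A small sketch (illustrative only, not part of the patch):

    // An unprivileged key sorts below a privileged one regardless of size...
    assert(EvictableKey(privileged = false, size = 5, id = 1)
      .compareTo(EvictableKey(privileged = true, size = 1, id = 1)) < 0)
    // ...and within the same privilege level, smaller sessions sort first.
    assert(EvictableKey(privileged = false, size = 2, id = 9)
      .compareTo(EvictableKey(privileged = false, size = 3, id = 1)) < 0)
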
-
-
-/**
-  * Caches fetch sessions.
-  *
-  * See tryEvict for an explanation of the cache eviction strategy.
-  *
-  * The FetchSessionCache is thread-safe because all of its methods are synchronized.
-  * Note that individual fetch sessions have their own locks which are separate from the
-  * FetchSessionCache lock.  In order to avoid deadlock, the FetchSessionCache lock
-  * must never be acquired while an individual FetchSession lock is already held.
-  *
-  * @param maxEntries The maximum number of entries that can be in the cache.
-  * @param evictionMs The minimum time that an entry must be unused in order to be evictable.
-  * @param sessionIdRange The number of sessionIds each cache shard handles. For a given instance, Math.max(1, shardNum * sessionIdRange) <= sessionId < (shardNum + 1) * sessionIdRange always holds.
-  * @param shardNum Identifier for this shard.
- */
-class FetchSessionCacheShard(private val maxEntries: Int,
-                             private val evictionMs: Long,
-                             val sessionIdRange: Int = Int.MaxValue,
-                             private val shardNum: Int = 0) extends Logging {
-
-  this.logIdent = s"[Shard $shardNum] "
-
-  private var numPartitions: Long = 0
-
-  // A map of session ID to FetchSession.
-  private val sessions = new mutable.HashMap[Int, FetchSession]
-
-  // Maps last used times to sessions.
-  private val lastUsed = new util.TreeMap[LastUsedKey, FetchSession]
-
-  // A map containing sessions which can be evicted by both privileged and
-  // unprivileged sessions.
-  private val evictableByAll = new util.TreeMap[EvictableKey, FetchSession]
-
-  // A map containing sessions which can be evicted by privileged sessions.
-  private val evictableByPrivileged = new util.TreeMap[EvictableKey, FetchSession]
-
-  // This metric is shared across all shards because newMeter returns an existing metric
-  // if one exists with the same name. It's safe for concurrent use because Meter is thread-safe.
-  private[server] val evictionsMeter = FetchSessionCache.metricsGroup.newMeter(FetchSession.INCREMENTAL_FETCH_SESSIONS_EVICTIONS_PER_SEC,
-    FetchSession.EVICTIONS, TimeUnit.SECONDS)
-
-  /**
-    * Get a session by session ID.
-    *
-    * @param sessionId  The session ID.
-    * @return           The session, or None if no such session was found.
-    */
-  def get(sessionId: Int): Option[FetchSession] = synchronized {
-    sessions.get(sessionId)
-  }
-
-  /**
-    * Get the number of entries currently in the fetch session cache.
-    */
-  def size: Int = synchronized {
-    sessions.size
-  }
-
-  /**
-    * Get the total number of cached partitions.
-    */
-  def totalPartitions: Long = synchronized {
-    numPartitions
-  }
-
-  /**
-    * Creates a new random session ID.  The new session ID will be positive and unique on this broker.
-    *
-    * @return   The new session ID.
-    */
-  def newSessionId(): Int = synchronized {
-    var id = 0
-    do {
-      id = ThreadLocalRandom.current().nextInt(Math.max(1, shardNum * sessionIdRange), (shardNum + 1) * sessionIdRange)
-    } while (sessions.contains(id) || id == INVALID_SESSION_ID)
-    id
-  }
-
-  /**
-    * Try to create a new session.
-    *
-    * @param now                The current time in milliseconds.
-    * @param privileged         True if the new entry we are trying to create is privileged.
-    * @param size               The number of cached partitions in the new entry we are trying to create.
-    * @param usesTopicIds       True if this session should use topic IDs.
-    * @param createPartitions   A callback function which creates the map of cached partitions and the mapping from
-    *                           topic name to topic ID for the topics.
-    * @return                   If we created a session, the ID; INVALID_SESSION_ID otherwise.
-    */
-  def maybeCreateSession(now: Long,
-                         privileged: Boolean,
-                         size: Int,
-                         usesTopicIds: Boolean,
-                         createPartitions: () => FetchSession.CACHE_MAP): Int =
-  synchronized {
-    // If there is room, create a new session entry.
-    if ((sessions.size < maxEntries) ||
-        tryEvict(privileged, EvictableKey(privileged, size, 0), now)) {
-      val partitionMap = createPartitions()
-      val session = new FetchSession(newSessionId(), privileged, partitionMap, usesTopicIds,
-          now, now, JFetchMetadata.nextEpoch(INITIAL_EPOCH))
-      debug(s"Created fetch session ${session.toString}")
-      sessions.put(session.id, session)
-      touch(session, now)
-      session.id
-    } else {
-      debug(s"No fetch session created for privileged=$privileged, 
size=$size.")
-      INVALID_SESSION_ID
-    }
-  }
-
-  /**
-    * Try to evict an entry from the session cache.
-    *
-    * A proposed new element A may evict an existing element B if:
-    * 1. A is privileged and B is not, or
-    * 2. B is considered "stale" because it has been inactive for a long time, or
-    * 3. A contains more partitions than B, and B is not recently created.
-    *
-    * Prior to KAFKA-9401, the session cache was not sharded and we looked at all
-    * entries while considering those eligible for eviction. Now eviction is done
-    * by considering entries on a per-shard basis.
-    *
-    * @param privileged True if the new entry we would like to add is privileged.
-    * @param key        The EvictableKey for the new entry we would like to add.
-    * @param now        The current time in milliseconds.
-    * @return           True if an entry was evicted; false otherwise.
-    */
-  private def tryEvict(privileged: Boolean, key: EvictableKey, now: Long): Boolean = synchronized {
-    // Try to evict an entry which is stale.
-    val lastUsedEntry = lastUsed.firstEntry
-    if (lastUsedEntry == null) {
-      trace("There are no cache entries to evict.")
-      false
-    } else if (now - lastUsedEntry.getKey.lastUsedMs > evictionMs) {
-      val session = lastUsedEntry.getValue
-      trace(s"Evicting stale FetchSession ${session.id}.")
-      remove(session)
-      evictionsMeter.mark()
-      true
-    } else {
-      // If there are no stale entries, check the first evictable entry.
-      // If it is less valuable than our proposed entry, evict it.
-      val map = if (privileged) evictableByPrivileged else evictableByAll
-      val evictableEntry = map.firstEntry
-      if (evictableEntry == null) {
-        trace("No evictable entries found.")
-        false
-      } else if (key.compareTo(evictableEntry.getKey) < 0) {
-        trace(s"Can't evict ${evictableEntry.getKey} with ${key.toString}")
-        false
-      } else {
-        trace(s"Evicting ${evictableEntry.getKey} with ${key.toString}.")
-        remove(evictableEntry.getValue)
-        evictionsMeter.mark()
-        true
-      }
-    }
-  }
-
-  def remove(sessionId: Int): Option[FetchSession] = synchronized {
-    get(sessionId) match {
-      case None => None
-      case Some(session) => remove(session)
-    }
-  }
-
-  /**
-    * Remove an entry from the session cache.
-    *
-    * @param session  The session.
-    *
-    * @return         The removed session, or None if there was no such session.
-    */
-  def remove(session: FetchSession): Option[FetchSession] = synchronized {
-    val evictableKey = session.synchronized {
-      lastUsed.remove(session.lastUsedKey)
-      session.evictableKey
-    }
-    evictableByAll.remove(evictableKey)
-    evictableByPrivileged.remove(evictableKey)
-    val removeResult = sessions.remove(session.id)
-    if (removeResult.isDefined) {
-      numPartitions = numPartitions - session.cachedSize
-    }
-    removeResult
-  }
-
-  /**
-    * Update a session's position in the lastUsed and evictable trees.
-    *
-    * @param session  The session.
-    * @param now      The current time in milliseconds.
-    */
-  def touch(session: FetchSession, now: Long): Unit = synchronized {
-    session.synchronized {
-      // Update the lastUsed map.
-      lastUsed.remove(session.lastUsedKey)
-      session.lastUsedMs = now
-      lastUsed.put(session.lastUsedKey, session)
-
-      val oldSize = session.cachedSize
-      if (oldSize != -1) {
-        val oldEvictableKey = session.evictableKey
-        evictableByPrivileged.remove(oldEvictableKey)
-        evictableByAll.remove(oldEvictableKey)
-        numPartitions = numPartitions - oldSize
-      }
-      session.cachedSize = session.size
-      val newEvictableKey = session.evictableKey
-      if ((!session.privileged) || (now - session.creationMs > evictionMs)) {
-        evictableByPrivileged.put(newEvictableKey, session)
-      }
-      if (now - session.creationMs > evictionMs) {
-        evictableByAll.put(newEvictableKey, session)
-      }
-      numPartitions = numPartitions + session.cachedSize
-    }
-  }
-}
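
The staleness rule in tryEvict can be exercised end to end with the Scala API above (illustrative only, not part of the patch): a single-slot shard replaces a session once it has sat idle longer than evictionMs.

    val shard = new FetchSessionCacheShard(1, 100, Int.MaxValue, 0) // 1 slot, 100 ms eviction
    val first = shard.maybeCreateSession(0, false, 1, true, () => new FetchSession.CACHE_MAP)
    // At t=200 the first session has been idle for 200 ms > 100 ms, so it is evictable.
    val second = shard.maybeCreateSession(200, false, 1, true, () => new FetchSession.CACHE_MAP)
    assert(shard.get(first).isEmpty && shard.get(second).isDefined)
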
-object FetchSessionCache {
-  // Changing the package or class name may cause incompatibility with existing code and metrics configuration
-  private val metricsPackage = "kafka.server"
-  private val metricsClassName = "FetchSessionCache"
-  private[server] val metricsGroup = new KafkaMetricsGroup(metricsPackage, metricsClassName)
-  private[server] val counter = new AtomicInteger(0)
-}
-
-class FetchSessionCache(private val cacheShards: Seq[FetchSessionCacheShard]) {
-  // Set up metrics.
-  FetchSessionCache.metricsGroup.newGauge(FetchSession.NUM_INCREMENTAL_FETCH_SESSIONS, () => cacheShards.map(_.size).sum)
-  FetchSessionCache.metricsGroup.newGauge(FetchSession.NUM_INCREMENTAL_FETCH_PARTITIONS_CACHED, () => cacheShards.map(_.totalPartitions).sum)
-
-  def getCacheShard(sessionId: Int): FetchSessionCacheShard = {
-    val shard = sessionId / cacheShards.head.sessionIdRange
-    // This assumes that cacheShards is sorted by shardNum
-    cacheShards(shard)
-  }
-
-  // Returns the shard in round-robin
-  def getNextCacheShard: FetchSessionCacheShard = {
-    val shardNum = Utils.toPositive(FetchSessionCache.counter.getAndIncrement()) % size
-    cacheShards(shardNum)
-  }
-
-  def size: Int = {
-    cacheShards.size
-  }
-}
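
Routing a session ID back to its shard is a single integer division, the inverse of the range arithmetic in newSessionId(). A sketch (illustrative only, not part of the patch):

    val shards = (0 until 4).map(n => new FetchSessionCacheShard(25, 100, Int.MaxValue / 4, n))
    val cache = new FetchSessionCache(shards)
    val sessionId = 3 * (Int.MaxValue / 4) + 42   // falls inside shard 3's range
    assert(cache.getCacheShard(sessionId) eq shards(3))
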
-
-class FetchManager(private val time: Time,
-                   private val cache: FetchSessionCache) extends Logging {
-
-  def this(time: Time, cacheShard: FetchSessionCacheShard) = this(time, new FetchSessionCache(Seq(cacheShard)))
-
-  def newContext(reqVersion: Short,
-                 reqMetadata: JFetchMetadata,
-                 isFollower: Boolean,
-                 fetchData: FetchSession.REQ_MAP,
-                 toForget: util.List[TopicIdPartition],
-                 topicNames: FetchSession.TOPIC_NAME_MAP): FetchContext = {
-    val context = if (reqMetadata.isFull) {
-      var removedFetchSessionStr = ""
-      if (reqMetadata.sessionId != INVALID_SESSION_ID) {
-        val cacheShard = cache.getCacheShard(reqMetadata.sessionId())
-        // Any session specified in a FULL fetch request will be closed.
-        if (cacheShard.remove(reqMetadata.sessionId).isDefined) {
-          removedFetchSessionStr = s" Removed fetch session ${reqMetadata.sessionId}."
-        }
-      }
-      var suffix = ""
-      val context = if (reqMetadata.epoch == FINAL_EPOCH) {
-        // If the epoch is FINAL_EPOCH, don't try to create a new session.
-        suffix = " Will not try to create a new session."
-        new SessionlessFetchContext(fetchData)
-      } else {
-        new FullFetchContext(time, cache, reqMetadata, fetchData, reqVersion >= 13, isFollower)
-      }
-      debug(s"Created a new full FetchContext with ${partitionsToLogString(fetchData.keySet)}."+
-        s"$removedFetchSessionStr$suffix")
-      context
-    } else {
-      val cacheShard = cache.getCacheShard(reqMetadata.sessionId())
-      cacheShard.synchronized {
-        cacheShard.get(reqMetadata.sessionId) match {
-          case None => {
-            debug(s"Session error for ${reqMetadata.sessionId}: no such 
session ID found.")
-            new SessionErrorContext(Errors.FETCH_SESSION_ID_NOT_FOUND, 
reqMetadata)
-          }
-          case Some(session) => session.synchronized {
-            if (session.epoch != reqMetadata.epoch) {
-              debug(s"Session error for ${reqMetadata.sessionId}: expected 
epoch " +
-                s"${session.epoch}, but got ${reqMetadata.epoch} instead.")
-              new SessionErrorContext(Errors.INVALID_FETCH_SESSION_EPOCH, 
reqMetadata)
-            } else if (session.usesTopicIds && reqVersion < 13 || 
!session.usesTopicIds && reqVersion >= 13)  {
-              debug(s"Session error for ${reqMetadata.sessionId}: expected  " +
-                s"${if (session.usesTopicIds) "to use topic IDs" else "to not 
use topic IDs"}" +
-                s", but request version $reqVersion means that we can not.")
-              new SessionErrorContext(Errors.FETCH_SESSION_TOPIC_ID_ERROR, 
reqMetadata)
-            } else {
-              val (added, updated, removed) = session.update(fetchData, toForget)
-              if (session.isEmpty) {
-                debug(s"Created a new sessionless FetchContext and closing session id ${session.id}, " +
-                  s"epoch ${session.epoch}: after removing ${partitionsToLogString(removed)}, " +
-                  s"there are no more partitions left.")
-                cacheShard.remove(session)
-                new SessionlessFetchContext(fetchData)
-              } else {
-                cacheShard.touch(session, time.milliseconds())
-                session.epoch = JFetchMetadata.nextEpoch(session.epoch)
-                debug(s"Created a new incremental FetchContext for session id ${session.id}, " +
-                  s"epoch ${session.epoch}: added ${partitionsToLogString(added)}, " +
-                  s"updated ${partitionsToLogString(updated)}, " +
-                  s"removed ${partitionsToLogString(removed)}")
-                new IncrementalFetchContext(time, reqMetadata, session, topicNames)
-              }
-            }
-          }
-        }
-      }
-    }
-    context
-  }
-
-  private def partitionsToLogString(partitions: util.Collection[TopicIdPartition]): String =
-    FetchSession.partitionsToLogString(partitions, isTraceEnabled)
-}
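
For orientation, newContext() above is the single entry point that selects one of the four contexts (sessionless, full, incremental, or session error). A minimal sketch using the Scala API this patch removes (the Java replacements in the server module keep a similar shape, as the updated tests below suggest):

    val fetchManager = new FetchManager(new MockTime(), new FetchSessionCacheShard(10, 1000, Int.MaxValue, 0))
    val context = fetchManager.newContext(13.toShort, JFetchMetadata.INITIAL, false,
      new util.LinkedHashMap[TopicIdPartition, FetchRequest.PartitionData](),
      util.List.of[TopicIdPartition](), new util.HashMap[Uuid, String]())
    // INITIAL metadata marks a full fetch, so this returns a FullFetchContext, which may
    // create and cache a new incremental session when its response is generated.
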
diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala b/core/src/main/scala/kafka/server/KafkaApis.scala
index 7afd5559b7c..e226ebfcbca 100644
--- a/core/src/main/scala/kafka/server/KafkaApis.scala
+++ b/core/src/main/scala/kafka/server/KafkaApis.scala
@@ -61,7 +61,7 @@ import org.apache.kafka.coordinator.group.{Group, GroupConfig, GroupConfigManage
 import org.apache.kafka.coordinator.share.ShareCoordinator
 import org.apache.kafka.metadata.{ConfigRepository, MetadataCache}
 import org.apache.kafka.security.DelegationTokenManager
-import org.apache.kafka.server.{ApiVersionManager, ClientMetricsManager, ProcessRole}
+import org.apache.kafka.server.{ApiVersionManager, ClientMetricsManager, FetchManager, ProcessRole}
 import org.apache.kafka.server.authorizer._
 import org.apache.kafka.server.common.{GroupVersion, RequestLocal, ShareVersion, StreamsVersion, TransactionVersion}
 import org.apache.kafka.server.share.context.ShareFetchContext
@@ -2823,11 +2823,11 @@ class KafkaApis(val requestChannel: RequestChannel,
               val timeoutMs = heartbeatIntervalMs * 2
 
               autoTopicCreationManager.createStreamsInternalTopics(topicsToCreate, requestContext, timeoutMs)
-              
+
               // Check for cached topic creation errors only if there's already a MISSING_INTERNAL_TOPICS status
-              val hasMissingInternalTopicsStatus = responseData.status() != null && 
+              val hasMissingInternalTopicsStatus = responseData.status() != null &&
                 responseData.status().stream().anyMatch(s => s.statusCode() == StreamsGroupHeartbeatResponse.Status.MISSING_INTERNAL_TOPICS.code())
-              
+
               if (hasMissingInternalTopicsStatus) {
                 val currentTimeMs = time.milliseconds()
                 val cachedErrors = autoTopicCreationManager.getStreamsInternalTopicCreationErrors(topicsToCreate.keys.toSet, currentTimeMs)
diff --git a/core/src/test/scala/unit/kafka/server/FetchSessionTest.scala b/core/src/test/scala/unit/kafka/server/FetchSessionTest.scala
index 8e4df446d88..47ec4480c31 100755
--- a/core/src/test/scala/unit/kafka/server/FetchSessionTest.scala
+++ b/core/src/test/scala/unit/kafka/server/FetchSessionTest.scala
@@ -24,6 +24,10 @@ import org.apache.kafka.common.record.MemoryRecords
 import org.apache.kafka.common.record.SimpleRecord
 import org.apache.kafka.common.requests.FetchMetadata.{FINAL_EPOCH, INVALID_SESSION_ID}
 import org.apache.kafka.common.requests.{FetchRequest, FetchResponse, FetchMetadata => JFetchMetadata}
+import org.apache.kafka.common.utils.ImplicitLinkedHashCollection
+import org.apache.kafka.server.FetchContext.{FullFetchContext, IncrementalFetchContext, SessionErrorContext, SessionlessFetchContext}
+import org.apache.kafka.server.{FetchContext, FetchManager, FetchSession, FetchSessionCacheShard}
+import org.apache.kafka.server.FetchSession.{CachedPartition, FetchSessionCache}
 import org.apache.kafka.server.util.MockTime
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{AfterEach, Test, Timeout}
@@ -40,15 +44,15 @@ class FetchSessionTest {
 
   @AfterEach
   def afterEach(): Unit = {
-    FetchSessionCache.metricsGroup.removeMetric(FetchSession.NUM_INCREMENTAL_FETCH_SESSIONS)
-    FetchSessionCache.metricsGroup.removeMetric(FetchSession.NUM_INCREMENTAL_FETCH_PARTITIONS_CACHED)
-    FetchSessionCache.metricsGroup.removeMetric(FetchSession.INCREMENTAL_FETCH_SESSIONS_EVICTIONS_PER_SEC)
-    FetchSessionCache.counter.set(0)
+    FetchSessionCache.METRICS_GROUP.removeMetric(FetchSession.NUM_INCREMENTAL_FETCH_SESSIONS)
+    FetchSessionCache.METRICS_GROUP.removeMetric(FetchSession.NUM_INCREMENTAL_FETCH_PARTITIONS_CACHED)
+    FetchSessionCache.METRICS_GROUP.removeMetric(FetchSession.INCREMENTAL_FETCH_SESSIONS_EVICTIONS_PER_SEC)
+    FetchSessionCache.COUNTER.set(0)
   }
 
   @Test
   def testNewSessionId(): Unit = {
-    val cacheShard = new FetchSessionCacheShard(3, 100)
+    val cacheShard = new FetchSessionCacheShard(3, 100, Int.MaxValue, 0)
     for (_ <- 0 to 10000) {
       val id = cacheShard.newSessionId()
       assertTrue(id > 0)
@@ -59,14 +63,14 @@ class FetchSessionTest {
     var i = 0
     for (sessionId <- sessionIds) {
       i = i + 1
-      assertTrue(cacheShard.get(sessionId).isDefined,
+      assertTrue(cacheShard.get(sessionId).isPresent,
         s"Missing session $i out of ${sessionIds.size} ($sessionId)")
     }
     assertEquals(sessionIds.size, cacheShard.size)
   }
 
-  private def dummyCreate(size: Int): FetchSession.CACHE_MAP = {
-    val cacheMap = new FetchSession.CACHE_MAP(size)
+  private def dummyCreate(size: Int): ImplicitLinkedHashCollection[CachedPartition] = {
+    val cacheMap = new ImplicitLinkedHashCollection[CachedPartition](size)
     for (i <- 0 until size) {
       cacheMap.add(new CachedPartition("test", Uuid.randomUuid(), i))
     }
@@ -75,34 +79,34 @@ class FetchSessionTest {
 
   @Test
   def testSessionCache(): Unit = {
-    val cacheShard = new FetchSessionCacheShard(3, 100)
+    val cacheShard = new FetchSessionCacheShard(3, 100, Int.MaxValue, 0)
     assertEquals(0, cacheShard.size)
-    val id1 = cacheShard.maybeCreateSession(0, privileged = false, 10, usesTopicIds = true, () => dummyCreate(10))
-    val id2 = cacheShard.maybeCreateSession(10, privileged = false, 20, usesTopicIds = true, () => dummyCreate(20))
-    val id3 = cacheShard.maybeCreateSession(20, privileged = false, 30, usesTopicIds = true, () => dummyCreate(30))
-    assertEquals(INVALID_SESSION_ID, cacheShard.maybeCreateSession(30, privileged = false, 40, usesTopicIds = true, () => dummyCreate(40)))
-    assertEquals(INVALID_SESSION_ID, cacheShard.maybeCreateSession(40, privileged = false, 5, usesTopicIds = true, () => dummyCreate(5)))
+    val id1 = cacheShard.maybeCreateSession(0, false, 10, true, () => dummyCreate(10))
+    val id2 = cacheShard.maybeCreateSession(10, false, 20, true, () => dummyCreate(20))
+    val id3 = cacheShard.maybeCreateSession(20, false, 30, true, () => dummyCreate(30))
+    assertEquals(INVALID_SESSION_ID, cacheShard.maybeCreateSession(30, false, 40, true, () => dummyCreate(40)))
+    assertEquals(INVALID_SESSION_ID, cacheShard.maybeCreateSession(40, false, 5, true, () => dummyCreate(5)))
     assertCacheContains(cacheShard, id1, id2, id3)
     cacheShard.touch(cacheShard.get(id1).get, 200)
-    val id4 = cacheShard.maybeCreateSession(210, privileged = false, 11, usesTopicIds = true, () => dummyCreate(11))
+    val id4 = cacheShard.maybeCreateSession(210, false, 11, true, () => dummyCreate(11))
     assertCacheContains(cacheShard, id1, id3, id4)
     cacheShard.touch(cacheShard.get(id1).get, 400)
     cacheShard.touch(cacheShard.get(id3).get, 390)
     cacheShard.touch(cacheShard.get(id4).get, 400)
-    val id5 = cacheShard.maybeCreateSession(410, privileged = false, 50, usesTopicIds = true, () => dummyCreate(50))
+    val id5 = cacheShard.maybeCreateSession(410, false, 50, true, () => dummyCreate(50))
     assertCacheContains(cacheShard, id3, id4, id5)
-    assertEquals(INVALID_SESSION_ID, cacheShard.maybeCreateSession(410, privileged = false, 5, usesTopicIds = true, () => dummyCreate(5)))
-    val id6 = cacheShard.maybeCreateSession(410, privileged = true, 5, usesTopicIds = true, () => dummyCreate(5))
+    assertEquals(INVALID_SESSION_ID, cacheShard.maybeCreateSession(410, false, 5, true, () => dummyCreate(5)))
+    val id6 = cacheShard.maybeCreateSession(410, true, 5, true, () => dummyCreate(5))
     assertCacheContains(cacheShard, id3, id5, id6)
   }
 
   @Test
   def testResizeCachedSessions(): Unit = {
-    val cacheShard = new FetchSessionCacheShard(2, 100)
+    val cacheShard = new FetchSessionCacheShard(2, 100, Int.MaxValue, 0)
     assertEquals(0, cacheShard.totalPartitions)
     assertEquals(0, cacheShard.size)
     assertEquals(0, cacheShard.evictionsMeter.count)
-    val id1 = cacheShard.maybeCreateSession(0, privileged = false, 2, usesTopicIds = true, () => dummyCreate(2))
+    val id1 = cacheShard.maybeCreateSession(0, false, 2, true, () => dummyCreate(2))
     assertTrue(id1 > 0)
     assertCacheContains(cacheShard, id1)
     val session1 = cacheShard.get(id1).get
@@ -110,7 +114,7 @@ class FetchSessionTest {
     assertEquals(2, cacheShard.totalPartitions)
     assertEquals(1, cacheShard.size)
     assertEquals(0, cacheShard.evictionsMeter.count)
-    val id2 = cacheShard.maybeCreateSession(0, privileged = false, 4, usesTopicIds = true, () => dummyCreate(4))
+    val id2 = cacheShard.maybeCreateSession(0, false, 4, true, () => dummyCreate(4))
     val session2 = cacheShard.get(id2).get
     assertTrue(id2 > 0)
     assertCacheContains(cacheShard, id1, id2)
@@ -119,7 +123,7 @@ class FetchSessionTest {
     assertEquals(0, cacheShard.evictionsMeter.count)
     cacheShard.touch(session1, 200)
     cacheShard.touch(session2, 200)
-    val id3 = cacheShard.maybeCreateSession(200, privileged = false, 5, usesTopicIds = true, () => dummyCreate(5))
+    val id3 = cacheShard.maybeCreateSession(200, false, 5, true, () => dummyCreate(5))
     assertTrue(id3 > 0)
     assertCacheContains(cacheShard, id2, id3)
     assertEquals(9, cacheShard.totalPartitions)
@@ -159,7 +163,7 @@ class FetchSessionTest {
   @Test
   def testCachedLeaderEpoch(): Unit = {
     val time = new MockTime()
-    val cacheShard = new FetchSessionCacheShard(10, 1000)
+    val cacheShard = new FetchSessionCacheShard(10, 1000, Int.MaxValue, 0)
     val fetchManager = new FetchManager(time, cacheShard)
 
     val topicIds = Map("foo" -> Uuid.randomUuid(), "bar" -> 
Uuid.randomUuid()).asJava
@@ -253,7 +257,7 @@ class FetchSessionTest {
   @Test
   def testLastFetchedEpoch(): Unit = {
     val time = new MockTime()
-    val cacheShard = new FetchSessionCacheShard(10, 1000)
+    val cacheShard = new FetchSessionCacheShard(10, 1000, Int.MaxValue, 0)
     val fetchManager = new FetchManager(time, cacheShard)
 
     val topicIds = Map("foo" -> Uuid.randomUuid(), "bar" -> 
Uuid.randomUuid()).asJava
@@ -352,7 +356,7 @@ class FetchSessionTest {
   @Test
   def testFetchRequests(): Unit = {
     val time = new MockTime()
-    val cacheShard = new FetchSessionCacheShard(10, 1000)
+    val cacheShard = new FetchSessionCacheShard(10, 1000, Int.MaxValue, 0)
     val fetchManager = new FetchManager(time, cacheShard)
     val topicNames = Map(Uuid.randomUuid() -> "foo", Uuid.randomUuid() -> "bar").asJava
     val topicIds = topicNames.asScala.map(_.swap).asJava
@@ -541,7 +545,7 @@ class FetchSessionTest {
   @ValueSource(booleans = Array(true, false))
   def testIncrementalFetchSession(usesTopicIds: Boolean): Unit = {
     val time = new MockTime()
-    val cacheShard = new FetchSessionCacheShard(10, 1000)
+    val cacheShard = new FetchSessionCacheShard(10, 1000, Int.MaxValue, 0)
     val fetchManager = new FetchManager(time, cacheShard)
     val topicNames = if (usesTopicIds) Map(Uuid.randomUuid() -> "foo", 
Uuid.randomUuid() -> "bar").asJava else Map[Uuid, String]().asJava
     val topicIds = topicNames.asScala.map(_.swap).asJava
@@ -605,10 +609,10 @@ class FetchSessionTest {
     context2.foreachPartition((topicIdPart, _) => {
       assertEquals(reqData2Iter.next(), topicIdPart)
     })
-    assertEquals(None, context2.getFetchOffset(tp0))
+    assertEquals(Optional.empty(), context2.getFetchOffset(tp0))
     assertEquals(10, context2.getFetchOffset(tp1).get)
     assertEquals(15, context2.getFetchOffset(tp2).get)
-    assertEquals(None, context2.getFetchOffset(new TopicIdPartition(barId, new 
TopicPartition("bar", 2))))
+    assertEquals(Optional.empty(), context2.getFetchOffset(new 
TopicIdPartition(barId, new TopicPartition("bar", 2))))
     val respData2 = new util.LinkedHashMap[TopicIdPartition, 
FetchResponseData.PartitionData]
     respData2.put(tp1, new FetchResponseData.PartitionData()
         .setPartitionIndex(1)
@@ -630,7 +634,7 @@ class FetchSessionTest {
   @Test
   def testFetchSessionWithUnknownIdOldRequestVersion(): Unit = {
     val time = new MockTime()
-    val cacheShard = new FetchSessionCacheShard(10, 1000)
+    val cacheShard = new FetchSessionCacheShard(10, 1000, Int.MaxValue, 0)
     val fetchManager = new FetchManager(time, cacheShard)
     val topicNames = Map(Uuid.randomUuid() -> "foo", Uuid.randomUuid() -> 
"bar").asJava
     val topicIds = topicNames.asScala.map(_.swap).asJava
@@ -678,7 +682,7 @@ class FetchSessionTest {
   @Test
   def testFetchSessionWithUnknownId(): Unit = {
     val time = new MockTime()
-    val cacheShard = new FetchSessionCacheShard(10, 1000)
+    val cacheShard = new FetchSessionCacheShard(10, 1000, Int.MaxValue, 0)
     val fetchManager = new FetchManager(time, cacheShard)
     val fooId = Uuid.randomUuid()
     val barId = Uuid.randomUuid()
@@ -787,7 +791,7 @@ class FetchSessionTest {
   @Test
   def testIncrementalFetchSessionWithIdsWhenSessionDoesNotUseIds() : Unit = {
     val time = new MockTime()
-    val cacheShard = new FetchSessionCacheShard(10, 1000)
+    val cacheShard = new FetchSessionCacheShard(10, 1000, Int.MaxValue, 0)
     val fetchManager = new FetchManager(time, cacheShard)
     val topicNames = new util.HashMap[Uuid, String]()
     val foo0 = new TopicIdPartition(Uuid.ZERO_UUID, new TopicPartition("foo", 
0))
@@ -841,7 +845,7 @@ class FetchSessionTest {
   @Test
   def testIncrementalFetchSessionWithoutIdsWhenSessionUsesIds() : Unit = {
     val time = new MockTime()
-    val cacheShard = new FetchSessionCacheShard(10, 1000)
+    val cacheShard = new FetchSessionCacheShard(10, 1000, Int.MaxValue, 0)
     val fetchManager = new FetchManager(time, cacheShard)
     val fooId = Uuid.randomUuid()
     val topicNames = new util.HashMap[Uuid, String]()
@@ -898,7 +902,7 @@ class FetchSessionTest {
   @Test
   def testFetchSessionUpdateTopicIdsBrokerSide(): Unit = {
     val time = new MockTime()
-    val cacheShard = new FetchSessionCacheShard(10, 1000)
+    val cacheShard = new FetchSessionCacheShard(10, 1000, Int.MaxValue, 0)
     val fetchManager = new FetchManager(time, cacheShard)
     val topicNames = Map(Uuid.randomUuid() -> "foo", Uuid.randomUuid() -> 
"bar").asJava
     val topicIds = topicNames.asScala.map(_.swap).asJava
@@ -990,7 +994,7 @@ class FetchSessionTest {
   @Test
   def testResolveUnknownPartitions(): Unit = {
     val time = new MockTime()
-    val cacheShard = new FetchSessionCacheShard(10, 1000)
+    val cacheShard = new FetchSessionCacheShard(10, 1000, Int.MaxValue, 0)
     val fetchManager = new FetchManager(time, cacheShard)
 
     def newContext(
@@ -1112,7 +1116,7 @@ class FetchSessionTest {
   @MethodSource(Array("idUsageCombinations"))
   def testToForgetPartitions(fooStartsResolved: Boolean, fooEndsResolved: 
Boolean): Unit = {
     val time = new MockTime()
-    val cacheShard = new FetchSessionCacheShard(10, 1000)
+    val cacheShard = new FetchSessionCacheShard(10, 1000, Int.MaxValue, 0)
     val fetchManager = new FetchManager(time, cacheShard)
 
     def newContext(
@@ -1212,7 +1216,7 @@ class FetchSessionTest {
   @Test
   def testUpdateAndGenerateResponseData(): Unit = {
     val time = new MockTime()
-    val cacheShard = new FetchSessionCacheShard(10, 1000)
+    val cacheShard = new FetchSessionCacheShard(10, 1000, Int.MaxValue, 0)
     val fetchManager = new FetchManager(time, cacheShard)
 
     def newContext(
@@ -1323,7 +1327,7 @@ class FetchSessionTest {
   def testFetchSessionExpiration(): Unit = {
     val time = new MockTime()
     // set maximum entries to 2 to allow for eviction later
-    val cacheShard = new FetchSessionCacheShard(2, 1000)
+    val cacheShard = new FetchSessionCacheShard(2, 1000, Int.MaxValue, 0)
     val fetchManager = new FetchManager(time, cacheShard)
     val fooId = Uuid.randomUuid()
     val topicNames = Map(fooId -> "foo").asJava
@@ -1363,7 +1367,7 @@ class FetchSessionTest {
     assertEquals(2, session1resp.responseData(topicNames, 
session1request1.version).size)
 
     // check session entered into cache
-    assertTrue(cacheShard.get(session1resp.sessionId()).isDefined)
+    assertTrue(cacheShard.get(session1resp.sessionId()).isPresent)
     time.sleep(500)
 
     // Create a second new fetch session
@@ -1400,8 +1404,8 @@ class FetchSessionTest {
     assertEquals(2, session2resp.responseData(topicNames, 
session2request1.version()).size())
 
     // both newly created entries are present in cache
-    assertTrue(cacheShard.get(session1resp.sessionId()).isDefined)
-    assertTrue(cacheShard.get(session2resp.sessionId()).isDefined)
+    assertTrue(cacheShard.get(session1resp.sessionId()).isPresent)
+    assertTrue(cacheShard.get(session2resp.sessionId()).isPresent)
     time.sleep(500)
 
     // Create an incremental fetch request for session 1
@@ -1457,16 +1461,16 @@ class FetchSessionTest {
     assertTrue(session3resp.sessionId() != INVALID_SESSION_ID)
     assertEquals(2, session3resp.responseData(topicNames, 
session3request1.version).size)
 
-    assertTrue(cacheShard.get(session1resp.sessionId()).isDefined)
-    assertFalse(cacheShard.get(session2resp.sessionId()).isDefined, "session 2 
should have been evicted by latest session, as session 1 was used more 
recently")
-    assertTrue(cacheShard.get(session3resp.sessionId()).isDefined)
+    assertTrue(cacheShard.get(session1resp.sessionId()).isPresent)
+    assertFalse(cacheShard.get(session2resp.sessionId()).isPresent, "session 2 
should have been evicted by latest session, as session 1 was used more 
recently")
+    assertTrue(cacheShard.get(session3resp.sessionId()).isPresent)
   }
 
   @Test
   def testPrivilegedSessionHandling(): Unit = {
     val time = new MockTime()
     // set maximum entries to 2 to allow for eviction later
-    val cacheShard = new FetchSessionCacheShard(2, 1000)
+    val cacheShard = new FetchSessionCacheShard(2, 1000, Int.MaxValue, 0)
     val fetchManager = new FetchManager(time, cacheShard)
     val fooId = Uuid.randomUuid()
     val topicNames = Map(fooId -> "foo").asJava
@@ -1544,8 +1548,8 @@ class FetchSessionTest {
     assertEquals(2, session2resp.responseData(topicNames, 
session2request.version).size)
 
     // both newly created entries are present in cache
-    assertTrue(cacheShard.get(session1resp.sessionId()).isDefined)
-    assertTrue(cacheShard.get(session2resp.sessionId()).isDefined)
+    assertTrue(cacheShard.get(session1resp.sessionId()).isPresent)
+    assertTrue(cacheShard.get(session2resp.sessionId()).isPresent)
     assertEquals(2, cacheShard.size)
     time.sleep(500)
 
@@ -1583,11 +1587,11 @@ class FetchSessionTest {
     assertTrue(session3resp.sessionId() != INVALID_SESSION_ID)
     assertEquals(2, session3resp.responseData(topicNames, 
session3request.version).size)
 
-    assertTrue(cacheShard.get(session1resp.sessionId()).isDefined)
+    assertTrue(cacheShard.get(session1resp.sessionId()).isPresent)
     // even though session 2 is more recent than session 1, and has not reached expiry time, it is less
     // privileged than session 1, and thus session 3 should be entered and session 2 evicted.
-    assertFalse(cacheShard.get(session2resp.sessionId()).isDefined, "session 2 
should have been evicted by session 3")
-    assertTrue(cacheShard.get(session3resp.sessionId()).isDefined)
+    assertFalse(cacheShard.get(session2resp.sessionId()).isPresent, "session 2 
should have been evicted by session 3")
+    assertTrue(cacheShard.get(session3resp.sessionId()).isPresent)
     assertEquals(2, cacheShard.size)
 
     time.sleep(501)
@@ -1626,16 +1630,16 @@ class FetchSessionTest {
     assertTrue(session4resp.sessionId() != INVALID_SESSION_ID)
     assertEquals(2, session4resp.responseData(topicNames, 
session4request.version).size)
 
-    assertFalse(cacheShard.get(session1resp.sessionId()).isDefined, "session 1 
should have been evicted by session 4 even though it is privileged as it has 
hit eviction time")
-    assertTrue(cacheShard.get(session3resp.sessionId()).isDefined)
-    assertTrue(cacheShard.get(session4resp.sessionId()).isDefined)
+    assertFalse(cacheShard.get(session1resp.sessionId()).isPresent, "session 1 
should have been evicted by session 4 even though it is privileged as it has 
hit eviction time")
+    assertTrue(cacheShard.get(session3resp.sessionId()).isPresent)
+    assertTrue(cacheShard.get(session4resp.sessionId()).isPresent)
     assertEquals(2, cacheShard.size)
   }
 
   @Test
   def testZeroSizeFetchSession(): Unit = {
     val time = new MockTime()
-    val cacheShard = new FetchSessionCacheShard(10, 1000)
+    val cacheShard = new FetchSessionCacheShard(10, 1000, Int.MaxValue, 0)
     val fetchManager = new FetchManager(time, cacheShard)
     val fooId = Uuid.randomUuid()
     val topicNames = Map(fooId -> "foo").asJava
@@ -1700,7 +1704,7 @@ class FetchSessionTest {
   @Test
   def testDivergingEpoch(): Unit = {
     val time = new MockTime()
-    val cacheShard = new FetchSessionCacheShard(10, 1000)
+    val cacheShard = new FetchSessionCacheShard(10, 1000, Int.MaxValue, 0)
     val fetchManager = new FetchManager(time, cacheShard)
     val topicNames = Map(Uuid.randomUuid() -> "foo", Uuid.randomUuid() -> 
"bar").asJava
     val topicIds = topicNames.asScala.map(_.swap).asJava
@@ -1785,7 +1789,7 @@ class FetchSessionTest {
   @Test
   def testDeprioritizesPartitionsWithRecordsOnly(): Unit = {
     val time = new MockTime()
-    val cacheShard = new FetchSessionCacheShard(10, 1000)
+    val cacheShard = new FetchSessionCacheShard(10, 1000, Int.MaxValue, 0)
     val fetchManager = new FetchManager(time, cacheShard)
     val topicIds = Map("foo" -> Uuid.randomUuid(), "bar" -> Uuid.randomUuid(), 
"zar" -> Uuid.randomUuid()).asJava
     val topicNames = topicIds.asScala.map(_.swap).asJava
@@ -1799,7 +1803,7 @@ class FetchSessionTest {
     reqData.put(tp3, new FetchRequest.PartitionData(tp3.topicId, 100, 0, 1000, 
Optional.of(5), Optional.of(4)))
 
     // Full fetch context returns all partitions in the response
-    val context1 = fetchManager.newContext(ApiKeys.FETCH.latestVersion(), 
JFetchMetadata.INITIAL, isFollower = false,
+    val context1 = fetchManager.newContext(ApiKeys.FETCH.latestVersion(), 
JFetchMetadata.INITIAL, false,
      reqData, Collections.emptyList(), topicNames)
     assertEquals(classOf[FullFetchContext], context1.getClass)
 
@@ -1827,7 +1831,7 @@ class FetchSessionTest {
 
     // Incremental fetch context returns partitions with changes but only 
deprioritizes
     // the partitions with records
-    val context2 = fetchManager.newContext(ApiKeys.FETCH.latestVersion(), new 
JFetchMetadata(resp1.sessionId, 1), isFollower = false,
+    val context2 = fetchManager.newContext(ApiKeys.FETCH.latestVersion(), new 
JFetchMetadata(resp1.sessionId, 1), false,
       reqData, Collections.emptyList(), topicNames)
     assertEquals(classOf[IncrementalFetchContext], context2.getClass)
 
@@ -1945,7 +1949,7 @@ class FetchSessionTest {
     // Given
     val numShards = 8
     val sessionIdRange = Int.MaxValue / numShards
-    val cacheShards = (0 until numShards).map(shardNum => new 
FetchSessionCacheShard(10, 1000, sessionIdRange, shardNum))
+    val cacheShards = (0 until numShards).map(shardNum => new 
FetchSessionCacheShard(10, 1000, sessionIdRange, shardNum)).asJava
     val cache = new FetchSessionCache(cacheShards)
 
     // When
@@ -1954,9 +1958,9 @@ class FetchSessionTest {
     val cache2 = cache.getCacheShard(sessionIdRange * 2)
 
     // Then
-    assertEquals(cache0, cacheShards(0))
-    assertEquals(cache1, cacheShards(1))
-    assertEquals(cache2, cacheShards(2))
+    assertEquals(cache0, cacheShards.get(0))
+    assertEquals(cache1, cacheShards.get(1))
+    assertEquals(cache2, cacheShards.get(2))
     assertThrows(classOf[IndexOutOfBoundsException], () => 
cache.getCacheShard(sessionIdRange * numShards))
   }
 
@@ -1965,12 +1969,12 @@ class FetchSessionTest {
     // Given
     val numShards = 8
     val sessionIdRange = Int.MaxValue / numShards
-    val cacheShards = (0 until numShards).map(shardNum => new 
FetchSessionCacheShard(10, 1000, sessionIdRange, shardNum))
+    val cacheShards = (0 until numShards).map(shardNum => new 
FetchSessionCacheShard(10, 1000, sessionIdRange, shardNum)).asJava
     val cache = new FetchSessionCache(cacheShards)
 
     // When / Then
     (0 until numShards*2).foreach { shardNum =>
-      assertEquals(cacheShards(shardNum % numShards), cache.getNextCacheShard)
+      assertEquals(cacheShards.get(shardNum % numShards), 
cache.getNextCacheShard)
     }
   }
 
@@ -1978,15 +1982,15 @@ class FetchSessionTest {
   def testFetchSessionCache_RoundRobinsIntoShards_WhenIntegerOverflows(): Unit 
= {
     // Given
     val maxInteger = Int.MaxValue
-    FetchSessionCache.counter.set(maxInteger + 1)
+    FetchSessionCache.COUNTER.set(maxInteger + 1)
     val numShards = 8
     val sessionIdRange = Int.MaxValue / numShards
-    val cacheShards = (0 until numShards).map(shardNum => new 
FetchSessionCacheShard(10, 1000, sessionIdRange, shardNum))
+    val cacheShards = (0 until numShards).map(shardNum => new 
FetchSessionCacheShard(10, 1000, sessionIdRange, shardNum)).asJava
     val cache = new FetchSessionCache(cacheShards)
 
     // When / Then
     (0 until numShards*2).foreach { shardNum =>
-      assertEquals(cacheShards(shardNum % numShards), cache.getNextCacheShard)
+      assertEquals(cacheShards.get(shardNum % numShards), 
cache.getNextCacheShard)
     }
   }
 }
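
A minimal sketch (not part of this commit) of wiring the sharded cache the
tests above exercise. It assumes, per the test code, that the shard
constructor arguments are (maxEntries, evictionMs, sessionIdRange, shardNum):

    import java.util.List;
    import java.util.stream.Collectors;
    import java.util.stream.IntStream;

    import org.apache.kafka.common.utils.Time;
    import org.apache.kafka.server.FetchManager;
    import org.apache.kafka.server.FetchSession.FetchSessionCache;
    import org.apache.kafka.server.FetchSessionCacheShard;

    public class ShardedFetchSessionCacheExample {
        public static void main(String[] args) {
            int numShards = 8;
            // Each shard owns a disjoint slice of the session-id space.
            int sessionIdRange = Integer.MAX_VALUE / numShards;
            List<FetchSessionCacheShard> shards = IntStream.range(0, numShards)
                .mapToObj(shardNum -> new FetchSessionCacheShard(10, 1000, sessionIdRange, shardNum))
                .collect(Collectors.toList());
            FetchSessionCache cache = new FetchSessionCache(shards);
            // getCacheShard(sessionId) routes an existing session id to its owning
            // shard; getNextCacheShard() round-robins new sessions across shards.
            FetchManager fetchManager = new FetchManager(Time.SYSTEM, cache);
        }
    }

Giving every shard its own session-id slice keeps shard lookup a simple
range check and, presumably, lets each shard synchronize independently.
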
diff --git a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala 
b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
index 6a26900a6d7..e6c6d377696 100644
--- a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
+++ b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
@@ -87,7 +87,8 @@ import org.apache.kafka.network.Session
 import org.apache.kafka.network.metrics.{RequestChannelMetrics, RequestMetrics}
 import org.apache.kafka.raft.{KRaftConfigs, QuorumConfig}
 import org.apache.kafka.security.authorizer.AclEntry
-import org.apache.kafka.server.{ClientMetricsManager, SimpleApiVersionManager}
+import org.apache.kafka.server.FetchContext.FullFetchContext
+import org.apache.kafka.server.{ClientMetricsManager, FetchManager, 
FetchSessionCacheShard, SimpleApiVersionManager}
 import org.apache.kafka.server.authorizer.{Action, AuthorizationResult, 
Authorizer}
 import org.apache.kafka.server.common.{FeatureVersion, FinalizedFeatures, 
GroupVersion, KRaftVersion, MetadataVersion, RequestLocal, ShareVersion, 
StreamsVersion, TransactionVersion}
 import org.apache.kafka.server.config.{ReplicationConfigs, ServerConfigs, 
ServerLogConfigs}
@@ -4435,9 +4436,7 @@ class KafkaApisTest extends Logging {
       Optional.empty()))
     val fetchDataBuilder = util.Map.of(tp, new 
FetchRequest.PartitionData(Uuid.ZERO_UUID, 0, 0, 1000,
       Optional.empty()))
-    val fetchMetadata = new JFetchMetadata(0, 0)
-    val fetchContext = new FullFetchContext(time, new 
FetchSessionCacheShard(1000, 100),
-      fetchMetadata, fetchData, false, false)
+    val fetchContext = new FullFetchContext(time, new 
FetchSessionCacheShard(1000, 100, Int.MaxValue, 0), fetchData, false, false)
     when(fetchManager.newContext(
       any[Short],
       any[JFetchMetadata],
@@ -4488,8 +4487,7 @@ class KafkaApisTest extends Logging {
     val fetchDataBuilder = util.Map.of(foo.topicPartition, new 
FetchRequest.PartitionData(foo.topicId, 0, 0, 1000,
       Optional.empty()))
     val fetchMetadata = new JFetchMetadata(0, 0)
-    val fetchContext = new FullFetchContext(time, new 
FetchSessionCacheShard(1000, 100),
-      fetchMetadata, fetchData, true, replicaId >= 0)
+    val fetchContext = new FullFetchContext(time, new 
FetchSessionCacheShard(1000, 100, Int.MaxValue, 0), fetchData, true, replicaId 
>= 0)
     // We expect to have the resolved partition, but we will simulate an 
unknown one with the fetchContext we return.
     when(fetchManager.newContext(
       ApiKeys.FETCH.latestVersion,
@@ -4558,9 +4556,7 @@ class KafkaApisTest extends Logging {
       Optional.empty()))
     val fetchDataBuilder = util.Map.of(tp, new 
FetchRequest.PartitionData(topicId, 0, 0, 1000,
       Optional.empty()))
-    val fetchMetadata = new JFetchMetadata(0, 0)
-    val fetchContext = new FullFetchContext(time, new 
FetchSessionCacheShard(1000, 100),
-      fetchMetadata, fetchData, true, false)
+    val fetchContext = new FullFetchContext(time, new 
FetchSessionCacheShard(1000, 100, Int.MaxValue, 0), fetchData, true, false)
     when(fetchManager.newContext(
       any[Short],
       any[JFetchMetadata],
@@ -9803,9 +9799,7 @@ class KafkaApisTest extends Logging {
         Optional.empty(), OptionalLong.empty(), Optional.empty(), 
OptionalInt.empty(), isReassigning)))
     })
 
-    val fetchMetadata = new JFetchMetadata(0, 0)
-    val fetchContext = new FullFetchContext(time, new 
FetchSessionCacheShard(1000, 100),
-      fetchMetadata, fetchData, true, true)
+    val fetchContext = new FullFetchContext(time, new 
FetchSessionCacheShard(1000, 100, Int.MaxValue, 0), fetchData, true, true)
     when(fetchManager.newContext(
       any[Short],
       any[JFetchMetadata],
@@ -14015,7 +14009,7 @@ class KafkaApisTest extends Logging {
     val topicId1 = Uuid.randomUuid
     metadataCache = initializeMetadataCacheWithShareGroupsEnabled()
     addTopicToMetadataCache(topicName1, 2, topicId = topicId1)
-    val topicCollection = new AlterShareGroupOffsetsRequestTopicCollection();
+    val topicCollection = new AlterShareGroupOffsetsRequestTopicCollection()
     topicCollection.addAll(util.List.of(
       new 
AlterShareGroupOffsetsRequestData.AlterShareGroupOffsetsRequestTopic()
         .setTopicName(topicName1)
@@ -14060,7 +14054,7 @@ class KafkaApisTest extends Logging {
     metadataCache = initializeMetadataCacheWithShareGroupsEnabled()
     addTopicToMetadataCache(topicName1, 2, topicId = topicId1)
     addTopicToMetadataCache(topicName2, 1, topicId = topicId2)
-    val topicCollection = new AlterShareGroupOffsetsRequestTopicCollection();
+    val topicCollection = new AlterShareGroupOffsetsRequestTopicCollection()
     topicCollection.addAll(util.List.of(
       new 
AlterShareGroupOffsetsRequestData.AlterShareGroupOffsetsRequestTopic()
         .setTopicName(topicName1)
@@ -14148,7 +14142,7 @@ class KafkaApisTest extends Logging {
     val topicId1 = Uuid.randomUuid
     metadataCache = initializeMetadataCacheWithShareGroupsEnabled()
     addTopicToMetadataCache(topicName1, 2, topicId = topicId1)
-    val topicCollection = new AlterShareGroupOffsetsRequestTopicCollection();
+    val topicCollection = new AlterShareGroupOffsetsRequestTopicCollection()
     topicCollection.addAll(util.List.of(
       new 
AlterShareGroupOffsetsRequestData.AlterShareGroupOffsetsRequestTopic()
         .setTopicName(topicName1)
diff --git 
a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/metadata/KRaftMetadataRequestBenchmark.java
 
b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/metadata/KRaftMetadataRequestBenchmark.java
index 7768cdfa92f..5b0cf9de898 100644
--- 
a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/metadata/KRaftMetadataRequestBenchmark.java
+++ 
b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/metadata/KRaftMetadataRequestBenchmark.java
@@ -21,7 +21,6 @@ import kafka.coordinator.transaction.TransactionCoordinator;
 import kafka.network.RequestChannel;
 import kafka.server.AutoTopicCreationManager;
 import kafka.server.ClientRequestQuotaManager;
-import kafka.server.FetchManager;
 import kafka.server.ForwardingManager;
 import kafka.server.KafkaApis;
 import kafka.server.KafkaConfig;
@@ -59,6 +58,7 @@ import org.apache.kafka.network.metrics.RequestChannelMetrics;
 import org.apache.kafka.raft.KRaftConfigs;
 import org.apache.kafka.raft.QuorumConfig;
 import org.apache.kafka.server.ClientMetricsManager;
+import org.apache.kafka.server.FetchManager;
 import org.apache.kafka.server.SimpleApiVersionManager;
 import org.apache.kafka.server.common.FinalizedFeatures;
 import org.apache.kafka.server.common.KRaftVersion;
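
For code that consumed these classes from core, the move is import-only.
The benchmark hunk above reduces to this swap (Java):

    // removed: the old Scala class in the core module
    import kafka.server.FetchManager;
    // added: the new Java class in the server module
    import org.apache.kafka.server.FetchManager;
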
diff --git a/server/src/main/java/org/apache/kafka/server/FetchContext.java 
b/server/src/main/java/org/apache/kafka/server/FetchContext.java
new file mode 100644
index 00000000000..69c8dffb752
--- /dev/null
+++ b/server/src/main/java/org/apache/kafka/server/FetchContext.java
@@ -0,0 +1,385 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.server;
+
+import org.apache.kafka.common.Node;
+import org.apache.kafka.common.TopicIdPartition;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.Uuid;
+import org.apache.kafka.common.message.FetchResponseData;
+import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.requests.FetchMetadata;
+import org.apache.kafka.common.requests.FetchRequest;
+import org.apache.kafka.common.requests.FetchResponse;
+import org.apache.kafka.common.utils.ImplicitLinkedHashCollection;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.server.FetchSession.CachedPartition;
+import org.apache.kafka.server.FetchSession.FetchSessionCache;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.NoSuchElementException;
+import java.util.Optional;
+import java.util.function.BiConsumer;
+
+import static 
org.apache.kafka.common.requests.FetchMetadata.INVALID_SESSION_ID;
+
+public sealed interface FetchContext {
+    /**
+     * Get the fetch offset for a given partition.
+     */
+    Optional<Long> getFetchOffset(TopicIdPartition part);
+
+    /**
+     * Apply a function to each partition in the fetch request.
+     */
+    void foreachPartition(BiConsumer<TopicIdPartition, 
FetchRequest.PartitionData> fun);
+
+    /**
+     * Get the response size to be used for quota computation. Since we are 
returning an empty response in case of
+     * throttling, we are not supposed to update the context until we know 
that we are not going to throttle.
+     */
+    int getResponseSize(LinkedHashMap<TopicIdPartition, 
FetchResponseData.PartitionData> updates, short versionId);
+
+    /**
+     * Updates the fetch context with new partition information.  Generates 
response data.
+     * The response data may require subsequent down-conversion.
+     */
+    FetchResponse 
updateAndGenerateResponseData(LinkedHashMap<TopicIdPartition, 
FetchResponseData.PartitionData> updates, List<Node> nodeEndpoints);
+
+    default String partitionsToLogString(Collection<TopicIdPartition> 
partitions, boolean isTraceEnabled) {
+        return FetchSession.partitionsToLogString(partitions, isTraceEnabled);
+    }
+
+    /**
+     * Return an empty throttled response due to quota violation.
+     */
+    default FetchResponse getThrottledResponse(int throttleTimeMs, List<Node> 
nodeEndpoints) {
+        return FetchResponse.of(Errors.NONE, throttleTimeMs, 
INVALID_SESSION_ID, new LinkedHashMap<>(), nodeEndpoints);
+    }
+
+    /**
+     * The fetch context for a fetch request that had a session error.
+     */
+    final class SessionErrorContext implements FetchContext {
+        private static final Logger LOGGER = 
LoggerFactory.getLogger(SessionErrorContext.class);
+
+        private final Errors error;
+
+        public SessionErrorContext(Errors error) {
+            this.error = error;
+        }
+
+        @Override
+        public Optional<Long> getFetchOffset(TopicIdPartition part) {
+            return Optional.empty();
+        }
+
+        @Override
+        public void foreachPartition(BiConsumer<TopicIdPartition, 
FetchRequest.PartitionData> fun) {
+        }
+
+        @Override
+        public int getResponseSize(LinkedHashMap<TopicIdPartition, 
FetchResponseData.PartitionData> updates, short versionId) {
+            return FetchResponse.sizeOf(versionId, 
Collections.emptyIterator());
+        }
+
+        /**
+         * Because of the fetch session error, we don't know what partitions 
were supposed to be in this request.
+         */
+        @Override
+        public FetchResponse 
updateAndGenerateResponseData(LinkedHashMap<TopicIdPartition, 
FetchResponseData.PartitionData> updates,
+                                                           List<Node> 
nodeEndpoints) {
+            LOGGER.debug("Session error fetch context returning {}", error);
+            return FetchResponse.of(error, 0, INVALID_SESSION_ID, new 
LinkedHashMap<>(), nodeEndpoints);
+        }
+    }
+
+    /**
+     * The fetch context for a sessionless fetch request.
+     */
+    final class SessionlessFetchContext implements FetchContext {
+        private static final Logger LOGGER = 
LoggerFactory.getLogger(SessionlessFetchContext.class);
+
+        private final Map<TopicIdPartition, FetchRequest.PartitionData> 
fetchData;
+
+        /**
+         * @param fetchData The partition data from the fetch request.
+         */
+        public SessionlessFetchContext(Map<TopicIdPartition, 
FetchRequest.PartitionData> fetchData) {
+            this.fetchData = fetchData;
+        }
+
+        @Override
+        public Optional<Long> getFetchOffset(TopicIdPartition part) {
+            return Optional.ofNullable(fetchData.get(part)).map(data -> 
data.fetchOffset);
+        }
+
+        @Override
+        public void foreachPartition(BiConsumer<TopicIdPartition, 
FetchRequest.PartitionData> fun) {
+            fetchData.forEach(fun);
+        }
+
+        @Override
+        public int getResponseSize(LinkedHashMap<TopicIdPartition, 
FetchResponseData.PartitionData> updates, short versionId) {
+            return FetchResponse.sizeOf(versionId, 
updates.entrySet().iterator());
+        }
+
+        @Override
+        public FetchResponse 
updateAndGenerateResponseData(LinkedHashMap<TopicIdPartition, 
FetchResponseData.PartitionData> updates,
+                                                           List<Node> 
nodeEndpoints) {
+            LOGGER.debug("Sessionless fetch context returning {}", 
partitionsToLogString(updates.keySet(), LOGGER.isTraceEnabled()));
+            return FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, 
updates, nodeEndpoints);
+        }
+    }
+
+    /**
+     * The fetch context for a full fetch request.
+     */
+    final class FullFetchContext implements FetchContext {
+        private static final Logger LOGGER = 
LoggerFactory.getLogger(FullFetchContext.class);
+
+        private final Time time;
+        private final FetchSessionCache cache;
+        private final Map<TopicIdPartition, FetchRequest.PartitionData> 
fetchData;
+        private final boolean usesTopicIds;
+        private final boolean isFromFollower;
+
+        public FullFetchContext(Time time,
+                                FetchSessionCacheShard cacheShard,
+                                Map<TopicIdPartition, 
FetchRequest.PartitionData> fetchData,
+                                boolean usesTopicIds,
+                                boolean isFromFollower) {
+            this(time, new FetchSessionCache(List.of(cacheShard)), fetchData, 
usesTopicIds, isFromFollower);
+        }
+
+        /**
+         * @param time           The clock to use
+         * @param cache          The fetch session cache
+         * @param fetchData      The partition data from the fetch request
+         * @param usesTopicIds   True if this session should use topic IDs
+         * @param isFromFollower True if this fetch request came from a 
follower
+         */
+        public FullFetchContext(Time time,
+                                FetchSessionCache cache,
+                                Map<TopicIdPartition, 
FetchRequest.PartitionData> fetchData,
+                                boolean usesTopicIds,
+                                boolean isFromFollower) {
+            this.time = time;
+            this.cache = cache;
+            this.fetchData = fetchData;
+            this.usesTopicIds = usesTopicIds;
+            this.isFromFollower = isFromFollower;
+        }
+
+        @Override
+        public Optional<Long> getFetchOffset(TopicIdPartition part) {
+            return Optional.ofNullable(fetchData.get(part)).map(data -> 
data.fetchOffset);
+        }
+
+        @Override
+        public void foreachPartition(BiConsumer<TopicIdPartition, 
FetchRequest.PartitionData> fun) {
+            fetchData.forEach(fun);
+        }
+
+        @Override
+        public int getResponseSize(LinkedHashMap<TopicIdPartition, 
FetchResponseData.PartitionData> updates, short versionId) {
+            return FetchResponse.sizeOf(versionId, 
updates.entrySet().iterator());
+        }
+
+        @Override
+        public FetchResponse 
updateAndGenerateResponseData(LinkedHashMap<TopicIdPartition, 
FetchResponseData.PartitionData> updates,
+                                                           List<Node> 
nodeEndpoints) {
+            FetchSessionCacheShard cacheShard = cache.getNextCacheShard();
+            int responseSessionId = 
cacheShard.maybeCreateSession(time.milliseconds(), isFromFollower,
+                updates.size(), usesTopicIds, () -> createNewSession(updates));
+            LOGGER.debug("Full fetch context with session id {} returning {}",
+                responseSessionId, partitionsToLogString(updates.keySet(), 
LOGGER.isTraceEnabled()));
+
+            return FetchResponse.of(Errors.NONE, 0, responseSessionId, 
updates, nodeEndpoints);
+        }
+
+        private ImplicitLinkedHashCollection<CachedPartition> createNewSession(
+                LinkedHashMap<TopicIdPartition, 
FetchResponseData.PartitionData> updates
+        ) {
+            ImplicitLinkedHashCollection<CachedPartition> cachedPartitions = 
new ImplicitLinkedHashCollection<>(updates.size());
+            updates.forEach((part, respData) -> {
+                FetchRequest.PartitionData reqData = fetchData.get(part);
+                cachedPartitions.mustAdd(new CachedPartition(part, reqData, 
respData));
+            });
+
+            return cachedPartitions;
+        }
+    }
+
+    /**
+     * The fetch context for an incremental fetch request.
+     */
+    final class IncrementalFetchContext implements FetchContext {
+        private static final Logger LOGGER = 
LoggerFactory.getLogger(IncrementalFetchContext.class);
+
+        private final FetchMetadata reqMetadata;
+        private final FetchSession session;
+        private final Map<Uuid, String> topicNames;
+
+        /**
+         * @param reqMetadata  The request metadata
+         * @param session      The incremental fetch request session
+         * @param topicNames   A mapping from topic ID to topic name used to 
resolve partitions already in the session
+         */
+        public IncrementalFetchContext(FetchMetadata reqMetadata,
+                                       FetchSession session,
+                                       Map<Uuid, String> topicNames) {
+            this.reqMetadata = reqMetadata;
+            this.session = session;
+            this.topicNames = topicNames;
+        }
+
+        @Override
+        public Optional<Long> getFetchOffset(TopicIdPartition part) {
+            return session.getFetchOffset(part);
+        }
+
+        @Override
+        public void foreachPartition(BiConsumer<TopicIdPartition, 
FetchRequest.PartitionData> fun) {
+            // Take the session lock and iterate over all the cached 
partitions.
+            synchronized (session) {
+                session.partitionMap().forEach(part -> {
+                    // Try to resolve an unresolved partition if it does not 
yet have a name
+                    if (session.usesTopicIds())
+                        part.maybeResolveUnknownName(topicNames);
+                    fun.accept(new TopicIdPartition(part.topicId(), new 
TopicPartition(part.topic(), part.partition())), part.reqData());
+                });
+            }
+        }
+
+        /**
+         * Iterator over the given partition map that selects the partitions to include in the response.
+         * If updateFetchContextAndRemoveUnselected is true, the fetch context is updated for the selected
+         * partitions, and unselected ones are removed as they are encountered.
+         */
+        private class PartitionIterator implements 
Iterator<Map.Entry<TopicIdPartition, FetchResponseData.PartitionData>> {
+            private final Iterator<Map.Entry<TopicIdPartition, 
FetchResponseData.PartitionData>> iter;
+            private final boolean updateFetchContextAndRemoveUnselected;
+            private Map.Entry<TopicIdPartition, 
FetchResponseData.PartitionData> nextElement;
+
+            public PartitionIterator(Iterator<Map.Entry<TopicIdPartition, FetchResponseData.PartitionData>> iter,
+                                     boolean updateFetchContextAndRemoveUnselected) {
+                this.iter = iter;
+                this.updateFetchContextAndRemoveUnselected = updateFetchContextAndRemoveUnselected;
+            }
+
+            @Override
+            public boolean hasNext() {
+                while ((nextElement == null) && iter.hasNext()) {
+                    Map.Entry<TopicIdPartition, 
FetchResponseData.PartitionData> element = iter.next();
+                    TopicIdPartition topicPart = element.getKey();
+                    FetchResponseData.PartitionData respData = 
element.getValue();
+                    CachedPartition cachedPart = 
session.partitionMap().find(new CachedPartition(topicPart));
+                    boolean mustRespond = cachedPart != null && 
cachedPart.maybeUpdateResponseData(respData, 
updateFetchContextAndRemoveUnselected);
+                    if (mustRespond) {
+                        nextElement = element;
+                        if (updateFetchContextAndRemoveUnselected && 
FetchResponse.recordsSize(respData) > 0) {
+                            session.partitionMap().remove(cachedPart);
+                            session.partitionMap().mustAdd(cachedPart);
+                        }
+                    } else if (updateFetchContextAndRemoveUnselected) {
+                        iter.remove();
+                    }
+                }
+
+                return nextElement != null;
+            }
+
+            @Override
+            public Map.Entry<TopicIdPartition, 
FetchResponseData.PartitionData> next() {
+                if (!hasNext())
+                    throw new NoSuchElementException();
+
+                Map.Entry<TopicIdPartition, FetchResponseData.PartitionData> 
element = nextElement;
+                nextElement = null;
+
+                return element;
+            }
+
+            @Override
+            public void remove() {
+                throw new UnsupportedOperationException();
+            }
+        }
+
+        @Override
+        public int getResponseSize(LinkedHashMap<TopicIdPartition, 
FetchResponseData.PartitionData> updates, short versionId) {
+            synchronized (session) {
+                int expectedEpoch = 
FetchMetadata.nextEpoch(reqMetadata.epoch());
+                if (session.epoch() != expectedEpoch) {
+                    return FetchResponse.sizeOf(versionId, 
Collections.emptyIterator());
+                } else {
+                    // Pass the partition iterator which updates neither the 
fetch context nor the partition map.
+                    return FetchResponse.sizeOf(versionId, new 
PartitionIterator(updates.entrySet().iterator(), false));
+                }
+            }
+        }
+
+        @Override
+        public FetchResponse 
updateAndGenerateResponseData(LinkedHashMap<TopicIdPartition, 
FetchResponseData.PartitionData> updates,
+                                                           List<Node> 
nodeEndpoints) {
+            synchronized (session) {
+                // Check to make sure that the session epoch didn't change in 
between
+                // creating this fetch context and generating this response.
+                int expectedEpoch = 
FetchMetadata.nextEpoch(reqMetadata.epoch());
+                if (session.epoch() != expectedEpoch) {
+                    LOGGER.info("Incremental fetch session {} expected epoch 
{}, but got {}. Possible duplicate request.",
+                        session.id(), expectedEpoch, session.epoch());
+                    return 
FetchResponse.of(Errors.INVALID_FETCH_SESSION_EPOCH, 0, session.id(), new 
LinkedHashMap<>(), nodeEndpoints);
+                } else {
+                    // Iterate over the update list using PartitionIterator. 
This will prune updates which don't need to be sent
+                    PartitionIterator partitionIter = new 
PartitionIterator(updates.entrySet().iterator(), true);
+                    while (partitionIter.hasNext()) {
+                        partitionIter.next();
+                    }
+                    LOGGER.debug("Incremental fetch context with session id {} 
returning {}", session.id(),
+                        partitionsToLogString(updates.keySet(), 
LOGGER.isTraceEnabled()));
+                    return FetchResponse.of(Errors.NONE, 0, session.id(), 
updates, nodeEndpoints);
+                }
+            }
+        }
+
+        @Override
+        public FetchResponse getThrottledResponse(int throttleTimeMs, 
List<Node> nodeEndpoints) {
+            synchronized (session) {
+                // Check to make sure that the session epoch didn't change in 
between
+                // creating this fetch context and generating this response.
+                int expectedEpoch = 
FetchMetadata.nextEpoch(reqMetadata.epoch());
+                if (session.epoch() != expectedEpoch) {
+                    LOGGER.info("Incremental fetch session {} expected epoch 
{}, but got {}. Possible duplicate request.",
+                        session.id(), expectedEpoch, session.epoch());
+                    return 
FetchResponse.of(Errors.INVALID_FETCH_SESSION_EPOCH, throttleTimeMs, 
session.id(), new LinkedHashMap<>(), nodeEndpoints);
+                } else {
+                    return FetchResponse.of(Errors.NONE, throttleTimeMs, 
session.id(), new LinkedHashMap<>(), nodeEndpoints);
+                }
+            }
+        }
+    }
+}
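
The KafkaApisTest changes earlier in this diff construct FullFetchContext
with its new signature, which drops the FetchMetadata argument. A
self-contained sketch of that pattern; the partition and response values
are illustrative, not taken from the commit:

    import java.util.LinkedHashMap;
    import java.util.List;
    import java.util.Map;
    import java.util.Optional;

    import org.apache.kafka.common.TopicIdPartition;
    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.common.Uuid;
    import org.apache.kafka.common.message.FetchResponseData;
    import org.apache.kafka.common.requests.FetchRequest;
    import org.apache.kafka.common.requests.FetchResponse;
    import org.apache.kafka.common.utils.Time;
    import org.apache.kafka.server.FetchContext.FullFetchContext;
    import org.apache.kafka.server.FetchSessionCacheShard;

    public class FullFetchContextExample {
        public static void main(String[] args) {
            TopicIdPartition tp = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0));
            Map<TopicIdPartition, FetchRequest.PartitionData> fetchData =
                Map.of(tp, new FetchRequest.PartitionData(tp.topicId(), 0, 0, 1000, Optional.empty()));

            // Single-shard convenience constructor introduced in this commit.
            FullFetchContext context = new FullFetchContext(
                Time.SYSTEM,
                new FetchSessionCacheShard(1000, 100, Integer.MAX_VALUE, 0),
                fetchData,
                true,    // usesTopicIds: request version >= 13
                false);  // isFromFollower: a consumer fetch

            LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> updates = new LinkedHashMap<>();
            updates.put(tp, new FetchResponseData.PartitionData()
                .setPartitionIndex(0)
                .setHighWatermark(10));
            // May register a new session in the shard, then builds the response.
            FetchResponse response = context.updateAndGenerateResponseData(updates, List.of());
        }
    }
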
diff --git a/server/src/main/java/org/apache/kafka/server/FetchManager.java 
b/server/src/main/java/org/apache/kafka/server/FetchManager.java
new file mode 100644
index 00000000000..87c7a6af149
--- /dev/null
+++ b/server/src/main/java/org/apache/kafka/server/FetchManager.java
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.server;
+
+import org.apache.kafka.common.TopicIdPartition;
+import org.apache.kafka.common.Uuid;
+import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.requests.FetchMetadata;
+import org.apache.kafka.common.requests.FetchRequest;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.server.FetchContext.FullFetchContext;
+import org.apache.kafka.server.FetchContext.IncrementalFetchContext;
+import org.apache.kafka.server.FetchContext.SessionErrorContext;
+import org.apache.kafka.server.FetchContext.SessionlessFetchContext;
+import org.apache.kafka.server.FetchSession.FetchSessionCache;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+import static org.apache.kafka.common.requests.FetchMetadata.FINAL_EPOCH;
+import static 
org.apache.kafka.common.requests.FetchMetadata.INVALID_SESSION_ID;
+
+public class FetchManager {
+    private static final Logger LOGGER = 
LoggerFactory.getLogger(FetchManager.class);
+
+    private final Time time;
+    private final FetchSessionCache cache;
+
+    public FetchManager(Time time, FetchSessionCacheShard cacheShard) {
+        this(time, new FetchSessionCache(List.of(cacheShard)));
+    }
+
+    public FetchManager(Time time, FetchSessionCache cache) {
+        this.time = time;
+        this.cache = cache;
+    }
+
+    public FetchContext newContext(short reqVersion,
+                                   FetchMetadata reqMetadata,
+                                   boolean isFollower,
+                                   Map<TopicIdPartition, 
FetchRequest.PartitionData> fetchData,
+                                   List<TopicIdPartition> toForget,
+                                   Map<Uuid, String> topicNames) {
+        if (reqMetadata.isFull()) {
+            String removedFetchSessionStr = "";
+            if (reqMetadata.sessionId() != INVALID_SESSION_ID) {
+                FetchSessionCacheShard cacheShard = 
cache.getCacheShard(reqMetadata.sessionId());
+                // Any session specified in a FULL fetch request will be 
closed.
+                if (cacheShard.remove(reqMetadata.sessionId()).isPresent())
+                    removedFetchSessionStr = " Removed fetch session " + 
reqMetadata.sessionId() + ".";
+            }
+            String suffix = "";
+            FetchContext fetchContext;
+            if (reqMetadata.epoch() == FINAL_EPOCH) {
+                // If the epoch is FINAL_EPOCH, don't try to create a new 
session.
+                suffix = " Will not try to create a new session.";
+                fetchContext = new SessionlessFetchContext(fetchData);
+            } else
+                fetchContext = new FullFetchContext(time, cache, fetchData, 
reqVersion >= 13, isFollower);
+
+            LOGGER.debug("Created a new full FetchContext with {}.{}{}",
+                partitionsToLogString(fetchData.keySet()), 
removedFetchSessionStr, suffix);
+            return fetchContext;
+        } else {
+            FetchSessionCacheShard cacheShard = 
cache.getCacheShard(reqMetadata.sessionId());
+            synchronized (cacheShard) {
+                Optional<FetchSession> sessionOpt = 
cacheShard.get(reqMetadata.sessionId());
+
+                if (sessionOpt.isEmpty()) {
+                    LOGGER.debug("Session error for {}: no such session ID 
found.", reqMetadata.sessionId());
+                    return new 
SessionErrorContext(Errors.FETCH_SESSION_ID_NOT_FOUND);
+                } else {
+                    FetchSession session = sessionOpt.get();
+                    synchronized (session) {
+                        if (session.epoch() != reqMetadata.epoch()) {
+                            LOGGER.debug("Session error for {}: expected epoch 
{}, but got {} instead.",
+                                reqMetadata.sessionId(), session.epoch(), 
reqMetadata.epoch());
+
+                            return new 
SessionErrorContext(Errors.INVALID_FETCH_SESSION_EPOCH);
+                        } else if (session.usesTopicIds() && reqVersion < 13 
|| !session.usesTopicIds() && reqVersion >= 13)  {
+                            LOGGER.debug("Session error for {}: expected  {}, 
but request version {} means that we can not.",
+                                reqMetadata.sessionId(), 
session.usesTopicIds() ? "to use topic IDs" : "to not use topic IDs", 
reqVersion);
+
+                            return new 
SessionErrorContext(Errors.FETCH_SESSION_TOPIC_ID_ERROR);
+                        } else {
+                            List<List<TopicIdPartition>> lists = 
session.update(fetchData, toForget);
+                            List<TopicIdPartition> added = lists.get(0);
+                            List<TopicIdPartition> updated = lists.get(1);
+                            List<TopicIdPartition> removed = lists.get(2);
+                            if (session.isEmpty()) {
+                                LOGGER.debug("Created a new sessionless 
FetchContext and closing session id {}, " +
+                                    "epoch {}: after removing {}, there are no 
more partitions left.",
+                                    session.id(), session.epoch(), 
partitionsToLogString(removed));
+                                cacheShard.remove(session);
+
+                                return new SessionlessFetchContext(fetchData);
+                            } else {
+                                cacheShard.touch(session, time.milliseconds());
+                                
session.setEpoch(FetchMetadata.nextEpoch(session.epoch()));
+                                LOGGER.debug("Created a new incremental 
FetchContext for session id {}, " +
+                                    "epoch {}: added {}, updated {}, removed 
{}",
+                                    session.id(), session.epoch(), 
partitionsToLogString(added),
+                                    partitionsToLogString(updated), 
partitionsToLogString(removed));
+
+                                return new 
IncrementalFetchContext(reqMetadata, session, topicNames);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    private String partitionsToLogString(Collection<TopicIdPartition> 
partitions) {
+        return FetchSession.partitionsToLogString(partitions, 
LOGGER.isTraceEnabled());
+    }
+}
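
A rough sketch of how a caller such as KafkaApis drives the manager; the
helper name and wiring are hypothetical, only the newContext signature
comes from the class above:

    import java.util.List;
    import java.util.Map;

    import org.apache.kafka.common.TopicIdPartition;
    import org.apache.kafka.common.Uuid;
    import org.apache.kafka.common.protocol.ApiKeys;
    import org.apache.kafka.common.requests.FetchMetadata;
    import org.apache.kafka.common.requests.FetchRequest;
    import org.apache.kafka.server.FetchContext;
    import org.apache.kafka.server.FetchManager;

    public class FetchManagerExample {
        // Opens a new fetch session with a full request.
        static FetchContext openSession(FetchManager fetchManager,
                                        Map<TopicIdPartition, FetchRequest.PartitionData> fetchData,
                                        Map<Uuid, String> topicNames) {
            // FetchMetadata.INITIAL (session id 0, epoch 0) marks a full request
            // that asks the broker to create a new incremental fetch session.
            return fetchManager.newContext(
                ApiKeys.FETCH.latestVersion(),
                FetchMetadata.INITIAL,
                false,       // isFollower: this is a consumer fetch
                fetchData,
                List.of(),   // toForget: nothing to drop from the session
                topicNames);
        }
    }

A follow-up request carrying the returned session id and the next epoch
takes the incremental path of newContext and yields an
IncrementalFetchContext; an unknown session id yields a SessionErrorContext
with FETCH_SESSION_ID_NOT_FOUND.
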
diff --git a/server/src/main/java/org/apache/kafka/server/FetchSession.java 
b/server/src/main/java/org/apache/kafka/server/FetchSession.java
new file mode 100644
index 00000000000..4fbd9be9431
--- /dev/null
+++ b/server/src/main/java/org/apache/kafka/server/FetchSession.java
@@ -0,0 +1,493 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.server;
+
+import org.apache.kafka.common.TopicIdPartition;
+import org.apache.kafka.common.Uuid;
+import org.apache.kafka.common.message.FetchResponseData;
+import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.requests.FetchMetadata;
+import org.apache.kafka.common.requests.FetchRequest;
+import org.apache.kafka.common.requests.FetchResponse;
+import org.apache.kafka.common.utils.ImplicitLinkedHashCollection;
+import org.apache.kafka.common.utils.Utils;
+import org.apache.kafka.server.metrics.KafkaMetricsGroup;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.concurrent.atomic.AtomicInteger;
+
+/**
+ * The fetch session.
+ * <p>
+ * Each fetch session is protected by its own lock, which must be taken before 
mutable
+ * fields are read or modified. This includes modification of the session 
partition map.
+ */
+public class FetchSession {
+    public static final String NUM_INCREMENTAL_FETCH_SESSIONS = 
"NumIncrementalFetchSessions";
+    public static final String NUM_INCREMENTAL_FETCH_PARTITIONS_CACHED = 
"NumIncrementalFetchPartitionsCached";
+    public static final String INCREMENTAL_FETCH_SESSIONS_EVICTIONS_PER_SEC = 
"IncrementalFetchSessionEvictionsPerSec";
+    static final String EVICTIONS = "evictions";
+
+    /**
+     * This is used by the FetchSessionCache to store the last known size of 
this session.
+     * If this is -1, the Session is not in the cache.
+     */
+    private int cachedSize = -1;
+
+    private final int id;
+    private final boolean privileged;
+    private final ImplicitLinkedHashCollection<CachedPartition> partitionMap;
+    private final boolean usesTopicIds;
+    private final long creationMs;
+    private volatile long lastUsedMs;
+    private volatile int epoch;
+
+    /**
+     * The fetch session.
+     *
+     * @param id                 The unique fetch session ID.
+     * @param privileged         True if this session is privileged. Sessions created by followers
+     *                           are privileged; sessions created by consumers are not.
+     * @param partitionMap       The CachedPartitionMap.
+     * @param usesTopicIds       True if this session is using topic IDs
+     * @param creationMs         The time in milliseconds when this session 
was created.
+     * @param lastUsedMs         The last used time in milliseconds. This should only be updated by
+     *                           FetchSessionCacheShard#touch.
+     * @param epoch              The fetch session sequence number.
+     */
+    public FetchSession(int id,
+                        boolean privileged,
+                        ImplicitLinkedHashCollection<CachedPartition> 
partitionMap,
+                        boolean usesTopicIds,
+                        long creationMs,
+                        long lastUsedMs,
+                        int epoch) {
+        this.id = id;
+        this.privileged = privileged;
+        this.partitionMap = partitionMap;
+        this.usesTopicIds = usesTopicIds;
+        this.creationMs = creationMs;
+        this.lastUsedMs = lastUsedMs;
+        this.epoch = epoch;
+    }
+
+    public static String partitionsToLogString(Collection<TopicIdPartition> partitions, boolean traceEnabled) {
+        return traceEnabled
+            ? "(" + partitions + ")"
+            : partitions.size() + " partition(s)";
+    }
+
+    public synchronized ImplicitLinkedHashCollection<CachedPartition> 
partitionMap() {
+        return partitionMap;
+    }
+
+    public int id() {
+        return id;
+    }
+
+    public boolean privileged() {
+        return privileged;
+    }
+
+    public boolean usesTopicIds() {
+        return usesTopicIds;
+    }
+
+    public long creationMs() {
+        return creationMs;
+    }
+
+    public int epoch() {
+        return epoch;
+    }
+
+    public void setEpoch(int epoch) {
+        this.epoch = epoch;
+    }
+
+    public long lastUsedMs() {
+        return lastUsedMs;
+    }
+
+    public void setLastUsedMs(long lastUsedMs) {
+        this.lastUsedMs = lastUsedMs;
+    }
+
+    public synchronized int cachedSize() {
+        return cachedSize;
+    }
+
+    public synchronized void setCachedSize(int cachedSize) {
+        this.cachedSize = cachedSize;
+    }
+
+    public synchronized int size() {
+        return partitionMap.size();
+    }
+
+    public synchronized boolean isEmpty() {
+        return partitionMap.isEmpty();
+    }
+
+    public synchronized LastUsedKey lastUsedKey() {
+        return new LastUsedKey(lastUsedMs, id);
+    }
+
+    public synchronized EvictableKey evictableKey() {
+        return new EvictableKey(privileged, cachedSize, id);
+    }
+
+    public synchronized FetchMetadata metadata() {
+        return new FetchMetadata(id, epoch);
+    }
+
+    public synchronized Optional<Long> getFetchOffset(TopicIdPartition 
topicIdPartition) {
+        return Optional.ofNullable(partitionMap.find(new 
CachedPartition(topicIdPartition)))
+            .map(partition -> partition.fetchOffset);
+    }
+
+    // Update the cached partition data based on the request. Returns three lists:
+    // the partitions added to, updated in, and removed from the session, in that order.
+    public synchronized List<List<TopicIdPartition>> 
update(Map<TopicIdPartition, FetchRequest.PartitionData> fetchData,
+                                                            
List<TopicIdPartition> toForget) {
+        List<TopicIdPartition> added = new ArrayList<>();
+        List<TopicIdPartition> updated = new ArrayList<>();
+        List<TopicIdPartition> removed = new ArrayList<>();
+
+        fetchData.forEach((topicPart, reqData) -> {
+            CachedPartition cachedPartitionKey = new 
CachedPartition(topicPart, reqData);
+            CachedPartition cachedPart = partitionMap.find(cachedPartitionKey);
+            if (cachedPart == null) {
+                partitionMap.mustAdd(cachedPartitionKey);
+                added.add(topicPart);
+            } else {
+                cachedPart.updateRequestParams(reqData);
+                updated.add(topicPart);
+            }
+        });
+
+        toForget.forEach(p -> {
+            if (partitionMap.remove(new CachedPartition(p)))
+                removed.add(p);
+        });
+
+        return List.of(added, updated, removed);
+    }
+
+    @Override
+    public String toString() {
+        synchronized (this) {
+            return "FetchSession(id=" + id +
+                ", privileged=" + privileged +
+                ", partitionMap.size=" + partitionMap.size() +
+                ", usesTopicIds=" + usesTopicIds +
+                ", creationMs=" + creationMs +
+                ", lastUsedMs=" + lastUsedMs +
+                ", epoch=" + epoch + ")";
+        }
+    }
+
+    /**
+     * A cached partition.
+     * <p>
+     * The broker maintains a set of these objects for each incremental fetch session.
+     * When an incremental fetch request is made, any partitions which are not explicitly
+     * enumerated in the fetch request are loaded from the cache.  Similarly, when an
+     * incremental fetch response is being prepared, any partitions that have not changed and
+     * do not have errors are left out of the response.
+     * <p>
+     * We store many of these objects, so it is important for them to be memory-efficient.
+     * That is why we store the topic and partition separately rather than storing a TopicPartition
+     * object.  The TopicPartition object takes up more memory because it is a separate JVM object,
+     * and because it caches its hash code in memory.
+     * <p>
+     * Note that fetcherLogStartOffset is the log start offset of the follower performing the
+     * fetch, whereas localLogStartOffset is the log start offset of the partition on this broker.
+     */
+    public static class CachedPartition implements ImplicitLinkedHashCollection.Element {
+
+        private volatile int cachedNext = ImplicitLinkedHashCollection.INVALID_INDEX;
+        private volatile int cachedPrev = ImplicitLinkedHashCollection.INVALID_INDEX;
+
+        private String topic;
+        private final Uuid topicId;
+        private final int partition;
+        private volatile int maxBytes;
+        private volatile long fetchOffset;
+        private long highWatermark;
+        private Optional<Integer> leaderEpoch;
+        private volatile long fetcherLogStartOffset;
+        private long localLogStartOffset;
+        private Optional<Integer> lastFetchedEpoch;
+
+        public CachedPartition(String topic, Uuid topicId, int partition) {
+            this(topic, topicId, partition, -1, -1, -1, Optional.empty(), -1, -1, Optional.empty());
+        }
+
+        public CachedPartition(TopicIdPartition part) {
+            this(part.topic(), part.topicId(), part.partition());
+        }
+
+        public CachedPartition(TopicIdPartition part, FetchRequest.PartitionData reqData) {
+            this(part.topic(), part.topicId(), part.partition(), reqData.maxBytes, reqData.fetchOffset, -1,
+                reqData.currentLeaderEpoch, reqData.logStartOffset, -1, reqData.lastFetchedEpoch);
+        }
+
+        public CachedPartition(TopicIdPartition part, FetchRequest.PartitionData reqData, FetchResponseData.PartitionData respData) {
+            this(part.topic(), part.topicId(), part.partition(), reqData.maxBytes, reqData.fetchOffset, respData.highWatermark(),
+                reqData.currentLeaderEpoch, reqData.logStartOffset, respData.logStartOffset(), reqData.lastFetchedEpoch);
+        }
+
+        public CachedPartition(String topic,
+                               Uuid topicId,
+                               int partition,
+                               int maxBytes,
+                               long fetchOffset,
+                               long highWatermark,
+                               Optional<Integer> leaderEpoch,
+                               long fetcherLogStartOffset,
+                               long localLogStartOffset,
+                               Optional<Integer> lastFetchedEpoch) {
+            this.topic = topic;
+            this.topicId = topicId;
+            this.partition = partition;
+            this.maxBytes = maxBytes;
+            this.fetchOffset = fetchOffset;
+            this.highWatermark = highWatermark;
+            this.leaderEpoch = leaderEpoch;
+            this.fetcherLogStartOffset = fetcherLogStartOffset;
+            this.localLogStartOffset = localLogStartOffset;
+            this.lastFetchedEpoch = lastFetchedEpoch;
+        }
+
+        @Override
+        public int next() {
+            return cachedNext;
+        }
+
+        @Override
+        public void setNext(int next) {
+            this.cachedNext = next;
+        }
+
+        @Override
+        public int prev() {
+            return cachedPrev;
+        }
+
+        @Override
+        public void setPrev(int prev) {
+            this.cachedPrev = prev;
+        }
+
+        public String topic() {
+            return topic;
+        }
+
+        public Uuid topicId() {
+            return topicId;
+        }
+
+        public int partition() {
+            return partition;
+        }
+
+        public FetchRequest.PartitionData reqData() {
+            return new FetchRequest.PartitionData(topicId, fetchOffset, fetcherLogStartOffset, maxBytes, leaderEpoch, lastFetchedEpoch);
+        }
+
+        public void updateRequestParams(FetchRequest.PartitionData reqData) {
+            // Update our cached request parameters.
+            maxBytes = reqData.maxBytes;
+            fetchOffset = reqData.fetchOffset;
+            fetcherLogStartOffset = reqData.logStartOffset;
+            leaderEpoch = reqData.currentLeaderEpoch;
+            lastFetchedEpoch = reqData.lastFetchedEpoch;
+        }
+
+        public void maybeResolveUnknownName(Map<Uuid, String> topicNames) {
+            if (topic == null)
+                topic = topicNames.get(topicId);
+        }
+
+        /**
+         * Determine whether the specified cached partition should be included in the FetchResponse we send back to
+         * the fetcher, and update it if requested.
+         * <p>
+         * This function should be called while holding the appropriate session lock.
+         *
+         * @param respData partition data
+         * @param updateResponseData if set to true, update this CachedPartition with new request and response data.
+         * @return True if this partition should be included in the response; false if it can be omitted.
+         */
+        public boolean maybeUpdateResponseData(FetchResponseData.PartitionData respData, boolean updateResponseData) {
+            // Check the response data.
+            boolean mustRespond = false;
+            if (FetchResponse.recordsSize(respData) > 0) {
+                // Partitions with new data are always included in the response.
+                mustRespond = true;
+            }
+            if (highWatermark != respData.highWatermark()) {
+                mustRespond = true;
+                if (updateResponseData)
+                    highWatermark = respData.highWatermark();
+            }
+            if (localLogStartOffset != respData.logStartOffset()) {
+                mustRespond = true;
+                if (updateResponseData)
+                    localLogStartOffset = respData.logStartOffset();
+            }
+            if (FetchResponse.isPreferredReplica(respData)) {
+                // If the broker computed a preferred read replica, we need to include it in the response.
+                mustRespond = true;
+            }
+            if (respData.errorCode() != Errors.NONE.code()) {
+                // Partitions with errors are always included in the response.
+                // We also set the cached highWatermark to an invalid offset, -1.
+                // This ensures that when the error goes away, we re-send the partition.
+                if (updateResponseData)
+                    highWatermark = -1;
+                mustRespond = true;
+            }
+            if (FetchResponse.isDivergingEpoch(respData)) {
+                // Partitions with a diverging epoch are always included in the response to trigger truncation.
+                mustRespond = true;
+            }
+
+            return mustRespond;
+        }
+
+        /**
+         * We have different equality checks depending on whether topic IDs are used.
+         * This means we need a different hash function as well. We use the topic name to
+         * calculate the hash if the ID is zero and unused; otherwise, we use the topic ID.
+         *
+         * @return the hash code for the CachedPartition, depending on which request version we are using.
+         */
+        @Override
+        public int hashCode() {
+            if (!topicId.equals(Uuid.ZERO_UUID))
+                return (31 * partition) + topicId.hashCode();
+            else
+                return (31 * partition) + topic.hashCode();
+        }
+
+        /**
+         * We have different equality checks depending on whether topic IDs are used.
+         * <p>
+         * This is because when we use topic IDs, a partition with a given ID and an unknown name is the same as a partition with that
+         * ID and a known name. This means we can only use topic ID and partition when determining equality.
+         * <p>
+         * On the other hand, if we are using topic names, all IDs are zero. This means we can only use topic name and partition
+         * when determining equality.
+         */
+        @Override
+        public boolean equals(Object that) {
+            if (that instanceof CachedPartition part) {
+                boolean condition;
+                if (!this.topicId.equals(Uuid.ZERO_UUID))
+                    condition = this.partition == part.partition && this.topicId.equals(part.topicId);
+                else
+                    condition = this.partition == part.partition && this.topic.equals(part.topic);
+
+                return this == part || condition;
+            }
+
+            return false;
+        }
+
+        @Override
+        public String toString() {
+            synchronized (this) {
+                return "CachedPartition(topic=" + topic +
+                    ", topicId=" + topicId +
+                    ", partition=" + partition +
+                    ", maxBytes=" + maxBytes +
+                    ", fetchOffset=" + fetchOffset +
+                    ", highWatermark=" + highWatermark +
+                    ", fetcherLogStartOffset=" + fetcherLogStartOffset +
+                    ", localLogStartOffset=" + localLogStartOffset +
+                    ")";
+            }
+        }
+    }
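To make the equality contract above concrete, a small sketch (the variable names and the random ID are hypothetical; Uuid here is org.apache.kafka.common.Uuid). With a non-zero topic ID, only the ID and partition participate in equals/hashCode, so an entry cached before the topic name was known still matches once the name resolves:

    Uuid id = Uuid.randomUuid();
    CachedPartition unnamed = new CachedPartition(null, id, 0);    // ID known, name unknown
    CachedPartition named = new CachedPartition("topic-a", id, 0); // same ID, name known
    // Both hash and compare equal, so they collide in the session's partitionMap.
    assert unnamed.equals(named) && unnamed.hashCode() == named.hashCode();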
+
+    public record LastUsedKey(long lastUsedMs, int id) implements Comparable<LastUsedKey> {
+        @Override
+        public int compareTo(LastUsedKey other) {
+            if (this.lastUsedMs != other.lastUsedMs)
+                return Long.compare(this.lastUsedMs, other.lastUsedMs);
+
+            return Integer.compare(this.id, other.id);
+        }
+    }
+
+    public record EvictableKey(boolean privileged, int size, int id) implements Comparable<EvictableKey> {
+        @Override
+        public int compareTo(EvictableKey other) {
+            if (this.privileged != other.privileged)
+                return Boolean.compare(this.privileged, other.privileged);
+
+            if (this.size != other.size)
+                return Integer.compare(this.size, other.size);
+
+            return Integer.compare(this.id, other.id);
+        }
+    }
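A short sketch of how this ordering drives eviction (values hypothetical): TreeMap.firstEntry() in the cache shard returns the smallest key, so the least valuable session, unprivileged and with the fewest cached partitions, is considered first:

    EvictableKey small = new EvictableKey(false, 10, 1);
    EvictableKey large = new EvictableKey(false, 500, 2);
    EvictableKey priv = new EvictableKey(true, 10, 3);
    assert small.compareTo(large) < 0; // fewer partitions sorts first among unprivileged
    assert large.compareTo(priv) < 0;  // any unprivileged key sorts before a privileged one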
+
+    public static class FetchSessionCache {
+        // Changing the package or class name may cause incompatibility with existing code and metrics configuration.
+        private static final String METRICS_PACKAGE = "kafka.server";
+        private static final String METRICS_CLASS_NAME = "FetchSessionCache";
+
+        public static final KafkaMetricsGroup METRICS_GROUP = new KafkaMetricsGroup(METRICS_PACKAGE, METRICS_CLASS_NAME);
+        public static final AtomicInteger COUNTER = new AtomicInteger(0);
+
+        private final List<FetchSessionCacheShard> cacheShards;
+
+        public FetchSessionCache(List<FetchSessionCacheShard> cacheShards) {
+            this.cacheShards = cacheShards;
+
+            // Set up metrics.
+            FetchSessionCache.METRICS_GROUP.newGauge(FetchSession.NUM_INCREMENTAL_FETCH_SESSIONS,
+                () -> cacheShards.stream().mapToInt(FetchSessionCacheShard::size).sum());
+            FetchSessionCache.METRICS_GROUP.newGauge(FetchSession.NUM_INCREMENTAL_FETCH_PARTITIONS_CACHED,
+                () -> cacheShards.stream().mapToLong(FetchSessionCacheShard::totalPartitions).sum());
+        }
+
+        public FetchSessionCacheShard getCacheShard(int sessionId) {
+            int shard = sessionId / cacheShards.get(0).sessionIdRange();
+            // This assumes that cacheShards is sorted by shardNum
+            return cacheShards.get(shard);
+        }
+
+        /**
+         * @return The next shard, selected in round-robin order
+         */
+        public FetchSessionCacheShard getNextCacheShard() {
+            int shardNum = Utils.toPositive(FetchSessionCache.COUNTER.getAndIncrement()) % size();
+            return cacheShards.get(shardNum);
+        }
+
+        public int size() {
+            return cacheShards.size();
+        }
+    }
+}
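As a worked example of the shard lookup above (all sizing values hypothetical): getCacheShard divides the session ID by the per-shard ID range, which assumes every shard was built with the same range and that the list is ordered by shardNum:

    // Two shards, each owning a range of 1000 session IDs.
    FetchSessionCacheShard shard0 = new FetchSessionCacheShard(1000, 120_000L, 1000, 0); // IDs [1, 1000)
    FetchSessionCacheShard shard1 = new FetchSessionCacheShard(1000, 120_000L, 1000, 1); // IDs [1000, 2000)
    FetchSession.FetchSessionCache cache = new FetchSession.FetchSessionCache(List.of(shard0, shard1));
    assert cache.getCacheShard(999) == shard0;  // 999 / 1000 == 0
    assert cache.getCacheShard(1000) == shard1; // 1000 / 1000 == 1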
diff --git a/server/src/main/java/org/apache/kafka/server/FetchSessionCacheShard.java b/server/src/main/java/org/apache/kafka/server/FetchSessionCacheShard.java
new file mode 100644
index 00000000000..e0ebe63d0d0
--- /dev/null
+++ b/server/src/main/java/org/apache/kafka/server/FetchSessionCacheShard.java
@@ -0,0 +1,297 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.server;
+
+import org.apache.kafka.common.requests.FetchMetadata;
+import org.apache.kafka.common.utils.ImplicitLinkedHashCollection;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.server.FetchSession.EvictableKey;
+import org.apache.kafka.server.FetchSession.LastUsedKey;
+
+import com.yammer.metrics.core.Meter;
+
+import org.slf4j.Logger;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Optional;
+import java.util.TreeMap;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Supplier;
+
+import static org.apache.kafka.common.requests.FetchMetadata.INITIAL_EPOCH;
+import static org.apache.kafka.common.requests.FetchMetadata.INVALID_SESSION_ID;
+
+/**
+ * Caches fetch sessions.
+ * <p>
+ * See {@link #tryEvict} for an explanation of the cache eviction strategy.
+ * <p>
+ * The FetchSessionCacheShard is thread-safe because all of its methods are synchronized.
+ * Note that individual fetch sessions have their own locks which are separate from the
+ * cache shard lock.  In order to avoid deadlock, the cache shard lock
+ * must never be acquired while an individual FetchSession lock is already held.
+ */
+public class FetchSessionCacheShard {
+    private final Logger logger;
+
+    private long numPartitions = 0;
+
+    /**
+     * A map of session ID to FetchSession.
+     */
+    private final Map<Integer, FetchSession> sessions = new HashMap<>();
+
+    /**
+     * Maps last used times to sessions.
+     */
+    private final TreeMap<LastUsedKey, FetchSession> lastUsed = new TreeMap<>();
+
+    /**
+     * A map containing sessions which can be evicted by both privileged and unprivileged sessions.
+     */
+    private final TreeMap<EvictableKey, FetchSession> evictableByAll = new TreeMap<>();
+
+    /**
+     * A map containing sessions which can be evicted by privileged sessions.
+     */
+    private final TreeMap<EvictableKey, FetchSession> evictableByPrivileged = new TreeMap<>();
+
+    /**
+     * This metric is shared across all shards because newMeter returns an existing metric
+     * if one exists with the same name. It's safe for concurrent use because Meter is thread-safe.
+     */
+    private final Meter evictionsMeter = FetchSession.FetchSessionCache.METRICS_GROUP.newMeter(
+        FetchSession.INCREMENTAL_FETCH_SESSIONS_EVICTIONS_PER_SEC,
+        FetchSession.EVICTIONS,
+        TimeUnit.SECONDS
+    );
+
+    private final int maxEntries;
+    private final long evictionMs;
+    private final int sessionIdRange;
+    private final int shardNum;
+
+    /**
+     * @param maxEntries The maximum number of entries that can be in the cache
+     * @param evictionMs The minimum time that an entry must be unused in order to be evictable
+     * @param sessionIdRange The number of sessionIds each cache shard handles.
+     *                       For a given instance, Math.max(1, shardNum * sessionIdRange) <= sessionId < (shardNum + 1) * sessionIdRange always holds.
+     * @param shardNum Identifier for this shard
+     */
+    public FetchSessionCacheShard(int maxEntries,
+                                  long evictionMs,
+                                  int sessionIdRange,
+                                  int shardNum) {
+        this.maxEntries = maxEntries;
+        this.evictionMs = evictionMs;
+        this.sessionIdRange = sessionIdRange;
+        this.shardNum = shardNum;
+        this.logger = new LogContext("[Shard " + shardNum + "] ").logger(FetchSessionCacheShard.class);
+    }
+
+    public int sessionIdRange() {
+        return sessionIdRange;
+    }
+
+    // Only for testing
+    public Meter evictionsMeter() {
+        return evictionsMeter;
+    }
+
+    /**
+     * Get a session by session ID.
+     *
+     * @param sessionId  The session ID.
+     * @return           The session, or empty if no such session was found.
+     */
+    public synchronized Optional<FetchSession> get(int sessionId) {
+        return Optional.ofNullable(sessions.get(sessionId));
+    }
+
+    /**
+     * Get the number of entries currently in the fetch session cache.
+     */
+    public synchronized int size() {
+        return sessions.size();
+    }
+
+    /**
+     * Get the total number of cached partitions.
+     */
+    public synchronized long totalPartitions() {
+        return numPartitions;
+    }
+
+    /**
+     * Creates a new random session ID.  The new session ID will be positive and unique on this broker.
+     *
+     * @return   The new session ID.
+     */
+    public synchronized int newSessionId() {
+        int id;
+        do {
+            id = ThreadLocalRandom.current().nextInt(Math.max(1, shardNum * sessionIdRange), (shardNum + 1) * sessionIdRange);
+        } while (sessions.containsKey(id) || id == INVALID_SESSION_ID);
+
+        return id;
+    }
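The Math.max(1, ...) lower bound only matters for shard 0: it keeps session IDs positive and skips 0, which FetchMetadata reserves as INVALID_SESSION_ID. A sketch of the resulting per-shard ranges, assuming a hypothetical sessionIdRange of 1000:

    int sessionIdRange = 1000; // hypothetical
    for (int shardNum = 0; shardNum < 3; shardNum++) {
        int lo = Math.max(1, shardNum * sessionIdRange);
        int hi = (shardNum + 1) * sessionIdRange;
        System.out.printf("shard %d draws session IDs from [%d, %d)%n", shardNum, lo, hi);
    }
    // Prints: shard 0 -> [1, 1000), shard 1 -> [1000, 2000), shard 2 -> [2000, 3000)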
+
+    /**
+     * Try to create a new session.
+     *
+     * @param now                The current time in milliseconds.
+     * @param privileged         True if the new entry we are trying to create is privileged.
+     * @param size               The number of cached partitions in the new entry we are trying to create.
+     * @param usesTopicIds       True if this session should use topic IDs.
+     * @param createPartitions   A callback function which creates the collection of cached partitions.
+     * @return                   If we created a session, the ID; INVALID_SESSION_ID otherwise.
+     */
+    public synchronized int maybeCreateSession(long now,
+                                               boolean privileged,
+                                               int size,
+                                               boolean usesTopicIds,
+                                               Supplier<ImplicitLinkedHashCollection<FetchSession.CachedPartition>> createPartitions) {
+        // If there is room, create a new session entry.
+        if ((sessions.size() < maxEntries) || tryEvict(privileged, new EvictableKey(privileged, size, 0), now)) {
+            ImplicitLinkedHashCollection<FetchSession.CachedPartition> partitionMap = createPartitions.get();
+            FetchSession session = new FetchSession(newSessionId(), privileged, partitionMap, usesTopicIds,
+                now, now, FetchMetadata.nextEpoch(INITIAL_EPOCH));
+            logger.debug("Created fetch session {}", session);
+            sessions.put(session.id(), session);
+            touch(session, now);
+
+            return session.id();
+        } else {
+            logger.debug("No fetch session created for privileged={}, size={}.", privileged, size);
+            return INVALID_SESSION_ID;
+        }
+    }
+
+    /**
+     * Try to evict an entry from the session cache.
+     * <p>
+     * A proposed new element A may evict an existing element B if:
+     * 1. A is privileged and B is not, or
+     * 2. B is considered "stale" because it has been inactive for a long time, or
+     * 3. A contains more partitions than B, and B is not recently created.
+     * <p>
+     * Prior to KAFKA-9401, the session cache was not sharded, and we looked at all
+     * entries while considering those eligible for eviction. Now eviction is done
+     * by considering entries on a per-shard basis.
+     * <p>
+     * @param privileged True if the new entry we would like to add is privileged
+     * @param key        The EvictableKey for the new entry we would like to add
+     * @param now        The current time in milliseconds
+     * @return           True if an entry was evicted; false otherwise.
+     */
+    private synchronized boolean tryEvict(boolean privileged, EvictableKey key, long now) {
+        // Try to evict an entry which is stale.
+        Map.Entry<LastUsedKey, FetchSession> lastUsedEntry = lastUsed.firstEntry();
+        if (lastUsedEntry == null) {
+            logger.trace("There are no cache entries to evict.");
+            return false;
+        } else if (now - lastUsedEntry.getKey().lastUsedMs() > evictionMs) {
+            FetchSession session = lastUsedEntry.getValue();
+            logger.trace("Evicting stale FetchSession {}.", session.id());
+            remove(session);
+            evictionsMeter.mark();
+            return true;
+        } else {
+            // If there are no stale entries, check the first evictable entry.
+            // If it is less valuable than our proposed entry, evict it.
+            TreeMap<EvictableKey, FetchSession> map = privileged ? evictableByPrivileged : evictableByAll;
+            Map.Entry<EvictableKey, FetchSession> evictableEntry = map.firstEntry();
+            if (evictableEntry == null) {
+                logger.trace("No evictable entries found.");
+                return false;
+            } else if (key.compareTo(evictableEntry.getKey()) < 0) {
+                logger.trace("Can't evict {} with {}", evictableEntry.getKey(), key);
+                return false;
+            } else {
+                logger.trace("Evicting {} with {}.", evictableEntry.getKey(), key);
+                remove(evictableEntry.getValue());
+                evictionsMeter.mark();
+                return true;
+            }
+        }
+    }
+
+    public synchronized Optional<FetchSession> remove(int sessionId) {
+        Optional<FetchSession> session = get(sessionId);
+        return session.isPresent() ? remove(session.get()) : Optional.empty();
+    }
+
+    /**
+     * Remove an entry from the session cache.
+     *
+     * @param session  The session.
+     *
+     * @return         The removed session, or empty if there was no such session.
+     */
+    public synchronized Optional<FetchSession> remove(FetchSession session) {
+        EvictableKey evictableKey;
+        synchronized (session) {
+            lastUsed.remove(session.lastUsedKey());
+            evictableKey = session.evictableKey();
+        }
+
+        evictableByAll.remove(evictableKey);
+        evictableByPrivileged.remove(evictableKey);
+        Optional<FetchSession> removeResult = Optional.ofNullable(sessions.remove(session.id()));
+
+        if (removeResult.isPresent())
+            numPartitions = numPartitions - session.cachedSize();
+
+        return removeResult;
+    }
+
+    /**
+     * Update a session's position in the lastUsed and evictable trees.
+     *
+     * @param session  The session
+     * @param now      The current time in milliseconds
+     */
+    public synchronized void touch(FetchSession session, long now) {
+        synchronized (session) {
+            // Update the lastUsed map.
+            lastUsed.remove(session.lastUsedKey());
+            session.setLastUsedMs(now);
+            lastUsed.put(session.lastUsedKey(), session);
+
+            int oldSize = session.cachedSize();
+            if (oldSize != -1) {
+                EvictableKey oldEvictableKey = session.evictableKey();
+                evictableByPrivileged.remove(oldEvictableKey);
+                evictableByAll.remove(oldEvictableKey);
+                numPartitions = numPartitions - oldSize;
+            }
+            session.setCachedSize(session.size());
+            EvictableKey newEvictableKey = session.evictableKey();
+
+            if ((!session.privileged()) || (now - session.creationMs() > evictionMs))
+                evictableByPrivileged.put(newEvictableKey, session);
+
+            if (now - session.creationMs() > evictionMs)
+                evictableByAll.put(newEvictableKey, session);
+
+            numPartitions = numPartitions + session.cachedSize();
+        }
+    }
+}
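To tie touch() and tryEvict() together, a hedged end-to-end sketch (sizing and timing values are hypothetical). A freshly created unprivileged session sits only in evictableByPrivileged, so another unprivileged request cannot displace it, while a privileged one can:

    FetchSessionCacheShard shard = new FetchSessionCacheShard(1, 120_000L, 1000, 0); // room for one session
    int first = shard.maybeCreateSession(0L, false, 5, true, ImplicitLinkedHashCollection::new);
    assert first != FetchMetadata.INVALID_SESSION_ID;
    // The shard is full and the new session is too young for evictableByAll,
    // so a second unprivileged request is turned away:
    int second = shard.maybeCreateSession(0L, false, 5, true, ImplicitLinkedHashCollection::new);
    assert second == FetchMetadata.INVALID_SESSION_ID;
    // A privileged request consults evictableByPrivileged and may evict the first session:
    int third = shard.maybeCreateSession(0L, true, 5, true, ImplicitLinkedHashCollection::new);
    assert third != FetchMetadata.INVALID_SESSION_ID;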

