This is an automated email from the ASF dual-hosted git repository.

dajac pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new a5f7c82a86 MINOR: Refactor `kafka.cluster.Replica` (#12081)
a5f7c82a86 is described below

commit a5f7c82a8674a3ae68b57571e8d9a564c87c7d54
Author: David Jacot <[email protected]>
AuthorDate: Mon Apr 25 21:43:32 2022 +0100

    MINOR: Refactor `kafka.cluster.Replica` (#12081)
    
    This patch refactors kafka.cluster.Replica, its usages and tests. This is 
part of the work in KAFKA-13790.
    
    Reviewers: Jason Gustafson <[email protected]>
---
 core/src/main/scala/kafka/cluster/Partition.scala  |  69 ++--
 core/src/main/scala/kafka/cluster/Replica.scala    | 166 +++++++---
 .../main/scala/kafka/server/ReplicaManager.scala   |  24 +-
 .../scala/unit/kafka/cluster/PartitionTest.scala   | 178 ++++++----
 .../scala/unit/kafka/cluster/ReplicaTest.scala     | 364 +++++++++++++++------
 .../test/scala/unit/kafka/log/UnifiedLogTest.scala |  85 +++++
 .../unit/kafka/server/ReplicaManagerTest.scala     |  16 +-
 7 files changed, 655 insertions(+), 247 deletions(-)

diff --git a/core/src/main/scala/kafka/cluster/Partition.scala 
b/core/src/main/scala/kafka/cluster/Partition.scala
index dffc5e9f17..f27a9cb558 100755
--- a/core/src/main/scala/kafka/cluster/Partition.scala
+++ b/core/src/main/scala/kafka/cluster/Partition.scala
@@ -561,6 +561,7 @@ class Partition(val topicPartition: TopicPartition,
           "Marking the topic partition as RECOVERED."
         )
       }
+
       updateAssignmentAndIsr(
         assignment = partitionState.replicas.asScala.map(_.toInt),
         isr = isr,
@@ -568,6 +569,7 @@ class Partition(val topicPartition: TopicPartition,
         removingReplicas = removingReplicas,
         LeaderRecoveryState.RECOVERED
       )
+
       try {
         createLogIfNotExists(partitionState.isNew, isFutureReplica = false, 
highWatermarkCheckpoints, topicId)
       } catch {
@@ -585,7 +587,7 @@ class Partition(val topicPartition: TopicPartition,
         s"ISR ${isr.mkString("[", ",", "]")} addingReplicas 
${addingReplicas.mkString("[", ",", "]")} " +
         s"removingReplicas ${removingReplicas.mkString("[", ",", "]")}. 
Previous leader epoch was $leaderEpoch.")
 
-      //We cache the leader epoch here, persisting it only if it's local 
(hence having a log dir)
+      // We cache the leader epoch here, persisting it only if it's local 
(hence having a log dir)
       leaderEpoch = partitionState.leaderEpoch
       leaderEpochStartOffsetOpt = Some(leaderEpochStartOffset)
       partitionEpoch = partitionState.partitionEpoch
@@ -598,31 +600,29 @@ class Partition(val topicPartition: TopicPartition,
       leaderLog.maybeAssignEpochStartOffset(leaderEpoch, 
leaderEpochStartOffset)
 
       val isNewLeader = !isLeader
-      val curTimeMs = time.milliseconds
-      // initialize lastCaughtUpTime of replicas as well as their 
lastFetchTimeMs and lastFetchLeaderLogEndOffset.
+      val currentTimeMs = time.milliseconds
+
+      // Initialize lastCaughtUpTime of replicas as well as their 
lastFetchTimeMs and
+      // lastFetchLeaderLogEndOffset.
       remoteReplicas.foreach { replica =>
-        val lastCaughtUpTimeMs = if 
(partitionState.isr.contains(replica.brokerId)) curTimeMs else 0L
-        replica.resetLastCaughtUpTime(leaderEpochStartOffset, curTimeMs, 
lastCaughtUpTimeMs)
+        replica.resetReplicaState(
+          currentTimeMs = currentTimeMs,
+          leaderEndOffset = leaderEpochStartOffset,
+          isNewLeader = isNewLeader,
+          isFollowerInSync = partitionState.isr.contains(replica.brokerId)
+        )
       }
 
-      if (isNewLeader) {
-        // mark local replica as the leader after converting hw
-        leaderReplicaIdOpt = Some(localBrokerId)
-        // reset log end offset for remote replicas
-        remoteReplicas.foreach { replica =>
-          replica.updateFetchState(
-            followerFetchOffsetMetadata = 
LogOffsetMetadata.UnknownOffsetMetadata,
-            followerStartOffset = UnifiedLog.UnknownOffset,
-            followerFetchTimeMs = 0L,
-            leaderEndOffset = UnifiedLog.UnknownOffset)
-        }
-      }
+      leaderReplicaIdOpt = Some(localBrokerId)
+
       // we may need to increment high watermark since ISR could be down to 1
-      (maybeIncrementLeaderHW(leaderLog), isNewLeader)
+      (maybeIncrementLeaderHW(leaderLog, currentTimeMs = currentTimeMs), 
isNewLeader)
     }
+
     // some delayed operations may be unblocked after HW changed
     if (leaderHWIncremented)
       tryCompleteDelayedRequests()
+
     isNewLeader
   }
 
@@ -693,7 +693,7 @@ class Partition(val topicPartition: TopicPartition,
       case Some(followerReplica) =>
         // No need to calculate low watermark if there is no delayed 
DeleteRecordsRequest
         val oldLeaderLW = if (delayedOperations.numDelayedDelete > 0) 
lowWatermarkIfLeader else -1L
-        val prevFollowerEndOffset = followerReplica.logEndOffset
+        val prevFollowerEndOffset = followerReplica.stateSnapshot.logEndOffset
         followerReplica.updateFetchState(
           followerFetchOffsetMetadata,
           followerStartOffset,
@@ -710,7 +710,7 @@ class Partition(val topicPartition: TopicPartition,
 
         // check if the HW of the partition can now be incremented
         // since the replica may already be in the ISR and its LEO has just 
incremented
-        val leaderHWIncremented = if (prevFollowerEndOffset != 
followerReplica.logEndOffset) {
+        val leaderHWIncremented = if (prevFollowerEndOffset != 
followerReplica.stateSnapshot.logEndOffset) {
           // the leader log may be updated by ReplicaAlterLogDirsThread so the 
following method must be in lock of
           // leaderIsrUpdateLock to prevent adding new hw to invalid log.
           inReadLock(leaderIsrUpdateLock) {
@@ -812,7 +812,7 @@ class Partition(val topicPartition: TopicPartition,
 
   private def isFollowerAtHighwatermark(followerReplica: Replica): Boolean = {
     leaderLogIfLocal.exists { leaderLog =>
-      val followerEndOffset = followerReplica.logEndOffset
+      val followerEndOffset = followerReplica.stateSnapshot.logEndOffset
       followerEndOffset >= leaderLog.highWatermark && 
leaderEpochStartOffsetOpt.exists(followerEndOffset >= _)
     }
   }
@@ -837,7 +837,7 @@ class Partition(val topicPartition: TopicPartition,
           }
 
           val curInSyncReplicaObjects = (curMaximalIsr - 
localBrokerId).flatMap(getReplica)
-          val replicaInfo = curInSyncReplicaObjects.map(replica => 
(replica.brokerId, replica.logEndOffset))
+          val replicaInfo = curInSyncReplicaObjects.map(replica => 
(replica.brokerId, replica.stateSnapshot.logEndOffset))
           val localLogInfo = (localBrokerId, localLogOrException.logEndOffset)
           val (ackedReplicas, awaitingReplicas) = (replicaInfo + 
localLogInfo).partition { _._2 >= requiredOffset}
 
@@ -886,15 +886,18 @@ class Partition(val topicPartition: TopicPartition,
    *
    * @return true if the HW was incremented, and false otherwise.
    */
-  private def maybeIncrementLeaderHW(leaderLog: UnifiedLog, curTime: Long = 
time.milliseconds): Boolean = {
+  private def maybeIncrementLeaderHW(leaderLog: UnifiedLog, currentTimeMs: 
Long = time.milliseconds): Boolean = {
     // maybeIncrementLeaderHW is in the hot path, the following code is 
written to
     // avoid unnecessary collection generation
-    var newHighWatermark = leaderLog.logEndOffsetMetadata
+    val leaderLogEndOffset = leaderLog.logEndOffsetMetadata
+    var newHighWatermark = leaderLogEndOffset
     remoteReplicasMap.values.foreach { replica =>
       // Note here we are using the "maximal", see explanation above
-      if (replica.logEndOffsetMetadata.messageOffset < 
newHighWatermark.messageOffset &&
-        (curTime - replica.lastCaughtUpTimeMs <= replicaLagTimeMaxMs || 
partitionState.maximalIsr.contains(replica.brokerId))) {
-        newHighWatermark = replica.logEndOffsetMetadata
+      val replicaState = replica.stateSnapshot
+      if (replicaState.logEndOffsetMetadata.messageOffset < 
newHighWatermark.messageOffset &&
+        (replicaState.isCaughtUp(leaderLogEndOffset.messageOffset, 
currentTimeMs, replicaLagTimeMaxMs)
+          || partitionState.maximalIsr.contains(replica.brokerId))) {
+        newHighWatermark = replicaState.logEndOffsetMetadata
       }
     }
 
@@ -909,7 +912,7 @@ class Partition(val topicPartition: TopicPartition,
         }
 
         if (isTraceEnabled) {
-          val replicaInfo = remoteReplicas.map(replica => (replica.brokerId, 
replica.logEndOffsetMetadata)).toSet
+          val replicaInfo = remoteReplicas.map(replica => (replica.brokerId, 
replica.stateSnapshot.logEndOffsetMetadata)).toSet
           val localLogInfo = (localBrokerId, 
localLogOrException.logEndOffsetMetadata)
           trace(s"Skipping update high watermark since new hw 
$newHighWatermark is not larger than old value. " +
             s"All current LEOs are ${(replicaInfo + 
localLogInfo).map(logEndOffsetString)}")
@@ -931,8 +934,9 @@ class Partition(val topicPartition: TopicPartition,
     // care has been taken to avoid generating unnecessary collections in this 
code
     var lowWaterMark = localLogOrException.logStartOffset
     remoteReplicas.foreach { replica =>
-      if (metadataCache.hasAliveBroker(replica.brokerId) && 
replica.logStartOffset < lowWaterMark) {
-        lowWaterMark = replica.logStartOffset
+      val logStartOffset = replica.stateSnapshot.logStartOffset
+      if (metadataCache.hasAliveBroker(replica.brokerId) && logStartOffset < 
lowWaterMark) {
+        lowWaterMark = logStartOffset
       }
     }
 
@@ -963,7 +967,7 @@ class Partition(val topicPartition: TopicPartition,
           if (!partitionState.isInflight && outOfSyncReplicaIds.nonEmpty) {
             val outOfSyncReplicaLog = outOfSyncReplicaIds.map { replicaId =>
               val logEndOffsetMessage = getReplica(replicaId)
-                .map(_.logEndOffset.toString)
+                .map(_.stateSnapshot.logEndOffset.toString)
                 .getOrElse("unknown")
               s"(brokerId: $replicaId, endOffset: $logEndOffsetMessage)"
             }.mkString(" ")
@@ -993,8 +997,7 @@ class Partition(val topicPartition: TopicPartition,
                                   currentTimeMs: Long,
                                   maxLagMs: Long): Boolean = {
     getReplica(replicaId).fold(true) { followerReplica =>
-      followerReplica.logEndOffset != leaderEndOffset &&
-        (currentTimeMs - followerReplica.lastCaughtUpTimeMs) > maxLagMs
+      !followerReplica.stateSnapshot.isCaughtUp(leaderEndOffset, 
currentTimeMs, maxLagMs)
     }
   }
 
diff --git a/core/src/main/scala/kafka/cluster/Replica.scala 
b/core/src/main/scala/kafka/cluster/Replica.scala
index 921faef061..0321488af4 100644
--- a/core/src/main/scala/kafka/cluster/Replica.scala
+++ b/core/src/main/scala/kafka/cluster/Replica.scala
@@ -13,7 +13,7 @@
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
-*/
+ */
 
 package kafka.cluster
 
@@ -22,35 +22,67 @@ import kafka.server.LogOffsetMetadata
 import kafka.utils.Logging
 import org.apache.kafka.common.TopicPartition
 
-class Replica(val brokerId: Int, val topicPartition: TopicPartition) extends 
Logging {
-  // the log end offset value, kept in all replicas;
-  // for local replica it is the log's end offset, for remote replicas its 
value is only updated by follower fetch
-  @volatile private[this] var _logEndOffsetMetadata = 
LogOffsetMetadata.UnknownOffsetMetadata
-  // the log start offset value, kept in all replicas;
-  // for local replica it is the log's start offset, for remote replicas its 
value is only updated by follower fetch
-  @volatile private[this] var _logStartOffset = UnifiedLog.UnknownOffset
+import java.util.concurrent.atomic.AtomicReference
+
+case class ReplicaState(
+  // The log start offset value, kept in all replicas; for local replica it is 
the
+  // log's start offset, for remote replicas its value is only updated by 
follower fetch.
+  logStartOffset: Long,
+
+  // The log end offset value, kept in all replicas; for local replica it is 
the
+  // log's end offset, for remote replicas its value is only updated by 
follower fetch.
+  logEndOffsetMetadata: LogOffsetMetadata,
 
-  // The log end offset value at the time the leader received the last 
FetchRequest from this follower
-  // This is used to determine the lastCaughtUpTimeMs of the follower
-  @volatile private[this] var lastFetchLeaderLogEndOffset = 0L
+  // The log end offset value at the time the leader received the last 
FetchRequest from this follower.
+  // This is used to determine the lastCaughtUpTimeMs of the follower. It is 
reset by the leader
+  // when a LeaderAndIsr request is received and might be reset when the 
leader appends a record
+  // to its log.
+  lastFetchLeaderLogEndOffset: Long,
 
-  // The time when the leader received the last FetchRequest from this follower
-  // This is used to determine the lastCaughtUpTimeMs of the follower
-  @volatile private[this] var lastFetchTimeMs = 0L
+  // The time when the leader received the last FetchRequest from this 
follower.
+  // This is used to determine the lastCaughtUpTimeMs of the follower.
+  lastFetchTimeMs: Long,
 
   // lastCaughtUpTimeMs is the largest time t such that the offset of most 
recent FetchRequest from this follower >=
   // the LEO of leader at time t. This is used to determine the lag of this 
follower and ISR of this partition.
-  @volatile private[this] var _lastCaughtUpTimeMs = 0L
+  lastCaughtUpTimeMs: Long
+) {
+  /**
+   * Returns the current log end offset of the replica.
+   */
+  def logEndOffset: Long = logEndOffsetMetadata.messageOffset
 
-  def logStartOffset: Long = _logStartOffset
+  /**
+   * Returns true when the replica is considered as "caught-up". A replica is
+   * considered "caught-up" when its log end offset is equal to the log end
+   * offset of the leader OR when the current time minus its last caught up
+   * time is smaller than or equal to the max replica lag.
+   */
+  def isCaughtUp(
+    leaderEndOffset: Long,
+    currentTimeMs: Long,
+    replicaMaxLagMs: Long
+  ): Boolean = {
+    leaderEndOffset == logEndOffset || currentTimeMs - lastCaughtUpTimeMs <= 
replicaMaxLagMs
+  }
+}
 
-  def logEndOffsetMetadata: LogOffsetMetadata = _logEndOffsetMetadata
+object ReplicaState {
+  val Empty: ReplicaState = ReplicaState(
+    logEndOffsetMetadata = LogOffsetMetadata.UnknownOffsetMetadata,
+    logStartOffset = UnifiedLog.UnknownOffset,
+    lastFetchLeaderLogEndOffset = 0L,
+    lastFetchTimeMs = 0L,
+    lastCaughtUpTimeMs = 0L
+  )
+}
 
-  def logEndOffset: Long = logEndOffsetMetadata.messageOffset
+class Replica(val brokerId: Int, val topicPartition: TopicPartition) extends 
Logging {
+  private val replicaState = new 
AtomicReference[ReplicaState](ReplicaState.Empty)
 
-  def lastCaughtUpTimeMs: Long = _lastCaughtUpTimeMs
+  def stateSnapshot: ReplicaState = replicaState.get
 
-  /*
+  /**
    * If the FetchRequest reads up to the log end offset of the leader when the 
current fetch request is received,
    * set `lastCaughtUpTimeMs` to the time when the current fetch request was 
received.
    *
@@ -62,39 +94,85 @@ class Replica(val brokerId: Int, val topicPartition: 
TopicPartition) extends Log
    * fetch request is always smaller than the leader's LEO, which can happen 
if small produce requests are received at
    * high frequency.
    */
-  def updateFetchState(followerFetchOffsetMetadata: LogOffsetMetadata,
-                       followerStartOffset: Long,
-                       followerFetchTimeMs: Long,
-                       leaderEndOffset: Long): Unit = {
-    if (followerFetchOffsetMetadata.messageOffset >= leaderEndOffset)
-      _lastCaughtUpTimeMs = math.max(_lastCaughtUpTimeMs, followerFetchTimeMs)
-    else if (followerFetchOffsetMetadata.messageOffset >= 
lastFetchLeaderLogEndOffset)
-      _lastCaughtUpTimeMs = math.max(_lastCaughtUpTimeMs, lastFetchTimeMs)
-
-    _logStartOffset = followerStartOffset
-    _logEndOffsetMetadata = followerFetchOffsetMetadata
-    lastFetchLeaderLogEndOffset = leaderEndOffset
-    lastFetchTimeMs = followerFetchTimeMs
+  def updateFetchState(
+    followerFetchOffsetMetadata: LogOffsetMetadata,
+    followerStartOffset: Long,
+    followerFetchTimeMs: Long,
+    leaderEndOffset: Long
+  ): Unit = {
+    replicaState.updateAndGet { currentReplicaState =>
+      val lastCaughtUpTime = if (followerFetchOffsetMetadata.messageOffset >= 
leaderEndOffset) {
+        math.max(currentReplicaState.lastCaughtUpTimeMs, followerFetchTimeMs)
+      } else if (followerFetchOffsetMetadata.messageOffset >= 
currentReplicaState.lastFetchLeaderLogEndOffset) {
+        math.max(currentReplicaState.lastCaughtUpTimeMs, 
currentReplicaState.lastFetchTimeMs)
+      } else {
+        currentReplicaState.lastCaughtUpTimeMs
+      }
+
+      ReplicaState(
+        logStartOffset = followerStartOffset,
+        logEndOffsetMetadata = followerFetchOffsetMetadata,
+        lastFetchLeaderLogEndOffset = math.max(leaderEndOffset, 
currentReplicaState.lastFetchLeaderLogEndOffset),
+        lastFetchTimeMs = followerFetchTimeMs,
+        lastCaughtUpTimeMs = lastCaughtUpTime
+      )
+    }
   }
 
-  def resetLastCaughtUpTime(curLeaderLogEndOffset: Long, curTimeMs: Long, 
lastCaughtUpTimeMs: Long): Unit = {
-    lastFetchLeaderLogEndOffset = curLeaderLogEndOffset
-    lastFetchTimeMs = curTimeMs
-    _lastCaughtUpTimeMs = lastCaughtUpTimeMs
+  /**
+   * When the leader is elected or re-elected, the state of the follower is 
reinitialized
+   * accordingly.
+   */
+  def resetReplicaState(
+    currentTimeMs: Long,
+    leaderEndOffset: Long,
+    isNewLeader: Boolean,
+    isFollowerInSync: Boolean
+  ): Unit = {
+    replicaState.updateAndGet { currentReplicaState =>
+      // When the leader is elected or re-elected, the follower's last caught 
up time
+      // is set to the current time if the follower is in the ISR, else to 0. 
The latter
+      // is done to ensure that the high watermark is not held back 
unnecessarily for
+      // a follower which is not in the ISR anymore.
+      val lastCaughtUpTimeMs = if (isFollowerInSync) currentTimeMs else 0L
+
+      if (isNewLeader) {
+        ReplicaState(
+          logStartOffset = UnifiedLog.UnknownOffset,
+          logEndOffsetMetadata = LogOffsetMetadata.UnknownOffsetMetadata,
+          lastFetchLeaderLogEndOffset = UnifiedLog.UnknownOffset,
+          lastFetchTimeMs = 0L,
+          lastCaughtUpTimeMs = lastCaughtUpTimeMs
+        )
+      } else {
+        ReplicaState(
+          logStartOffset = currentReplicaState.logStartOffset,
+          logEndOffsetMetadata = currentReplicaState.logEndOffsetMetadata,
+          lastFetchLeaderLogEndOffset = leaderEndOffset,
+          // When the leader is re-elected, the follower's last fetch time is
+          // set to the current time if the follower is in the ISR, else to 0.
+          // The latter is done to ensure that the follower is not brought back
+          // into the ISR before a fetch is received.
+          lastFetchTimeMs = if (isFollowerInSync) currentTimeMs else 0L,
+          lastCaughtUpTimeMs = lastCaughtUpTimeMs
+        )
+      }
+    }
     trace(s"Reset state of replica to $this")
   }
 
   override def toString: String = {
+    val replicaState = this.replicaState.get
     val replicaString = new StringBuilder
-    replicaString.append("Replica(replicaId=" + brokerId)
+    replicaString.append(s"Replica(replicaId=$brokerId")
     replicaString.append(s", topic=${topicPartition.topic}")
     replicaString.append(s", partition=${topicPartition.partition}")
-    replicaString.append(s", lastCaughtUpTimeMs=$lastCaughtUpTimeMs")
-    replicaString.append(s", logStartOffset=$logStartOffset")
-    replicaString.append(s", logEndOffset=$logEndOffset")
-    replicaString.append(s", logEndOffsetMetadata=$logEndOffsetMetadata")
-    replicaString.append(s", 
lastFetchLeaderLogEndOffset=$lastFetchLeaderLogEndOffset")
-    replicaString.append(s", lastFetchTimeMs=$lastFetchTimeMs")
+    replicaString.append(s", 
lastCaughtUpTimeMs=${replicaState.lastCaughtUpTimeMs}")
+    replicaString.append(s", logStartOffset=${replicaState.logStartOffset}")
+    replicaString.append(s", 
logEndOffset=${replicaState.logEndOffsetMetadata.messageOffset}")
+    replicaString.append(s", 
logEndOffsetMetadata=${replicaState.logEndOffsetMetadata}")
+    replicaString.append(s", 
lastFetchLeaderLogEndOffset=${replicaState.lastFetchLeaderLogEndOffset}")
+    replicaString.append(s", lastFetchTimeMs=${replicaState.lastFetchTimeMs}")
     replicaString.append(")")
     replicaString.toString
   }
diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala 
b/core/src/main/scala/kafka/server/ReplicaManager.scala
index 831212f9eb..d72c7351f6 100644
--- a/core/src/main/scala/kafka/server/ReplicaManager.scala
+++ b/core/src/main/scala/kafka/server/ReplicaManager.scala
@@ -1241,18 +1241,26 @@ class ReplicaManager(val config: KafkaConfig,
         replicaSelectorOpt.flatMap { replicaSelector =>
           val replicaEndpoints = 
metadataCache.getPartitionReplicaEndpoints(partition.topicPartition,
             new ListenerName(clientMetadata.listenerName))
-          val replicaInfos = partition.remoteReplicas
+          val replicaInfoSet = mutable.Set[ReplicaView]()
+
+          partition.remoteReplicas.foreach { replica =>
+            val replicaState = replica.stateSnapshot
             // Exclude replicas that don't have the requested offset (whether 
or not if they're in the ISR)
-            .filter(replica => replica.logEndOffset >= fetchOffset && 
replica.logStartOffset <= fetchOffset)
-            .map(replica => new DefaultReplicaView(
-              replicaEndpoints.getOrElse(replica.brokerId, Node.noNode()),
-              replica.logEndOffset,
-              currentTimeMs - replica.lastCaughtUpTimeMs))
+            if (replicaState.logEndOffset >= fetchOffset && 
replicaState.logStartOffset <= fetchOffset) {
+              replicaInfoSet.add(new DefaultReplicaView(
+                replicaEndpoints.getOrElse(replica.brokerId, Node.noNode()),
+                replicaState.logEndOffset,
+                currentTimeMs - replicaState.lastCaughtUpTimeMs
+              ))
+            }
+          }
 
           val leaderReplica = new DefaultReplicaView(
             replicaEndpoints.getOrElse(leaderReplicaId, Node.noNode()),
-            partition.localLogOrException.logEndOffset, 0L)
-          val replicaInfoSet = mutable.Set[ReplicaView]() ++= replicaInfos += 
leaderReplica
+            partition.localLogOrException.logEndOffset,
+            0L
+          )
+          replicaInfoSet.add(leaderReplica)
 
           val partitionInfo = new DefaultPartitionView(replicaInfoSet.asJava, 
leaderReplica)
           replicaSelector.select(partition.topicPartition, clientMetadata, 
partitionInfo).asScala.collect {
diff --git a/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala 
b/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
index 0b949945eb..ec0ad044f7 100644
--- a/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
+++ b/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
@@ -1047,10 +1047,11 @@ class PartitionTest extends AbstractPartitionTest {
           .setIsNew(true),
         offsetCheckpoints, None), "Expected become leader transition to 
succeed")
 
-    val remoteReplica = partition.getReplica(remoteBrokerId).get
-    assertEquals(initializeTimeMs, remoteReplica.lastCaughtUpTimeMs)
-    assertEquals(LogOffsetMetadata.UnknownOffsetMetadata.messageOffset, 
remoteReplica.logEndOffset)
-    assertEquals(UnifiedLog.UnknownOffset, remoteReplica.logStartOffset)
+    assertReplicaState(partition, remoteBrokerId,
+      lastCaughtUpTimeMs = initializeTimeMs,
+      logStartOffset = UnifiedLog.UnknownOffset,
+      logEndOffset = UnifiedLog.UnknownOffset
+    )
 
     time.sleep(500)
 
@@ -1060,9 +1061,11 @@ class PartitionTest extends AbstractPartitionTest {
       followerFetchTimeMs = time.milliseconds(),
       leaderEndOffset = 6L)
 
-    assertEquals(initializeTimeMs, remoteReplica.lastCaughtUpTimeMs)
-    assertEquals(3L, remoteReplica.logEndOffset)
-    assertEquals(0L, remoteReplica.logStartOffset)
+    assertReplicaState(partition, remoteBrokerId,
+      lastCaughtUpTimeMs = initializeTimeMs,
+      logStartOffset = 0L,
+      logEndOffset = 3L
+    )
 
     time.sleep(500)
 
@@ -1072,10 +1075,11 @@ class PartitionTest extends AbstractPartitionTest {
       followerFetchTimeMs = time.milliseconds(),
       leaderEndOffset = 6L)
 
-    assertEquals(time.milliseconds(), remoteReplica.lastCaughtUpTimeMs)
-    assertEquals(6L, remoteReplica.logEndOffset)
-    assertEquals(0L, remoteReplica.logStartOffset)
-
+    assertReplicaState(partition, remoteBrokerId,
+      lastCaughtUpTimeMs = time.milliseconds(),
+      logStartOffset = 0L,
+      logEndOffset = 6L
+    )
   }
 
   @Test
@@ -1102,9 +1106,11 @@ class PartitionTest extends AbstractPartitionTest {
         offsetCheckpoints, None), "Expected become leader transition to 
succeed")
     assertEquals(Set(brokerId), partition.partitionState.isr)
 
-    val remoteReplica = partition.getReplica(remoteBrokerId).get
-    assertEquals(LogOffsetMetadata.UnknownOffsetMetadata.messageOffset, 
remoteReplica.logEndOffset)
-    assertEquals(UnifiedLog.UnknownOffset, remoteReplica.logStartOffset)
+    assertReplicaState(partition, remoteBrokerId,
+      lastCaughtUpTimeMs = 0L,
+      logStartOffset = UnifiedLog.UnknownOffset,
+      logEndOffset = UnifiedLog.UnknownOffset
+    )
 
     partition.updateFollowerFetchState(remoteBrokerId,
       followerFetchOffsetMetadata = LogOffsetMetadata(10),
@@ -1155,9 +1161,11 @@ class PartitionTest extends AbstractPartitionTest {
         offsetCheckpoints, None), "Expected become leader transition to 
succeed")
     assertEquals(Set(brokerId), partition.partitionState.isr)
 
-    val remoteReplica = partition.getReplica(remoteBrokerId).get
-    assertEquals(LogOffsetMetadata.UnknownOffsetMetadata.messageOffset, 
remoteReplica.logEndOffset)
-    assertEquals(UnifiedLog.UnknownOffset, remoteReplica.logStartOffset)
+    assertReplicaState(partition, remoteBrokerId,
+      lastCaughtUpTimeMs = 0L,
+      logStartOffset = UnifiedLog.UnknownOffset,
+      logEndOffset = UnifiedLog.UnknownOffset
+    )
 
     partition.updateFollowerFetchState(remoteBrokerId,
       followerFetchOffsetMetadata = LogOffsetMetadata(3),
@@ -1166,8 +1174,11 @@ class PartitionTest extends AbstractPartitionTest {
       leaderEndOffset = 6L)
 
     assertEquals(Set(brokerId), partition.partitionState.isr)
-    assertEquals(3L, remoteReplica.logEndOffset)
-    assertEquals(0L, remoteReplica.logStartOffset)
+    assertReplicaState(partition, remoteBrokerId,
+      lastCaughtUpTimeMs = 0L,
+      logStartOffset = 0L,
+      logEndOffset = 3L
+    )
 
     partition.updateFollowerFetchState(remoteBrokerId,
       followerFetchOffsetMetadata = LogOffsetMetadata(10),
@@ -1180,8 +1191,11 @@ class PartitionTest extends AbstractPartitionTest {
     assertEquals(isrItem.leaderAndIsr.isr, List(brokerId, remoteBrokerId))
     assertEquals(Set(brokerId), partition.partitionState.isr)
     assertEquals(Set(brokerId, remoteBrokerId), 
partition.partitionState.maximalIsr)
-    assertEquals(10L, remoteReplica.logEndOffset)
-    assertEquals(0L, remoteReplica.logStartOffset)
+    assertReplicaState(partition, remoteBrokerId,
+      lastCaughtUpTimeMs = time.milliseconds(),
+      logStartOffset = 0L,
+      logEndOffset = 10L
+    )
 
     // Complete the ISR expansion
     alterIsrManager.completeIsrUpdate(2)
@@ -1216,9 +1230,11 @@ class PartitionTest extends AbstractPartitionTest {
         offsetCheckpoints, None), "Expected become leader transition to 
succeed")
     assertEquals(Set(brokerId), partition.partitionState.isr)
 
-    val remoteReplica = partition.getReplica(remoteBrokerId).get
-    assertEquals(LogOffsetMetadata.UnknownOffsetMetadata.messageOffset, 
remoteReplica.logEndOffset)
-    assertEquals(UnifiedLog.UnknownOffset, remoteReplica.logStartOffset)
+    assertReplicaState(partition, remoteBrokerId,
+      lastCaughtUpTimeMs = 0L,
+      logStartOffset = UnifiedLog.UnknownOffset,
+      logEndOffset = UnifiedLog.UnknownOffset
+    )
 
     partition.updateFollowerFetchState(remoteBrokerId,
       followerFetchOffsetMetadata = LogOffsetMetadata(10),
@@ -1230,8 +1246,11 @@ class PartitionTest extends AbstractPartitionTest {
     assertEquals(Set(brokerId), partition.inSyncReplicaIds)
     assertEquals(Set(brokerId, remoteBrokerId), 
partition.partitionState.maximalIsr)
     assertEquals(alterIsrManager.isrUpdates.size, 1)
-    assertEquals(10L, remoteReplica.logEndOffset)
-    assertEquals(0L, remoteReplica.logStartOffset)
+    assertReplicaState(partition, remoteBrokerId,
+      lastCaughtUpTimeMs = time.milliseconds(),
+      logStartOffset = 0L,
+      logEndOffset = 10L
+    )
 
     // Simulate failure callback
     alterIsrManager.failIsrUpdate(Errors.INVALID_UPDATE_VERSION)
@@ -1321,10 +1340,11 @@ class PartitionTest extends AbstractPartitionTest {
     ))
     assertEquals(0L, partition.localLogOrException.highWatermark)
 
-    val remoteReplica = partition.getReplica(remoteBrokerId).get
-    assertEquals(initializeTimeMs, remoteReplica.lastCaughtUpTimeMs)
-    assertEquals(LogOffsetMetadata.UnknownOffsetMetadata.messageOffset, 
remoteReplica.logEndOffset)
-    assertEquals(UnifiedLog.UnknownOffset, remoteReplica.logStartOffset)
+    assertReplicaState(partition, remoteBrokerId,
+      lastCaughtUpTimeMs = initializeTimeMs,
+      logStartOffset = UnifiedLog.UnknownOffset,
+      logEndOffset = UnifiedLog.UnknownOffset
+    )
 
     // On initialization, the replica is considered caught up and should not 
be removed
     partition.maybeShrinkIsr()
@@ -1376,10 +1396,11 @@ class PartitionTest extends AbstractPartitionTest {
     ))
     assertEquals(0L, partition.localLogOrException.highWatermark)
 
-    val remoteReplica = partition.getReplica(remoteBrokerId).get
-    assertEquals(initializeTimeMs, remoteReplica.lastCaughtUpTimeMs)
-    assertEquals(LogOffsetMetadata.UnknownOffsetMetadata.messageOffset, 
remoteReplica.logEndOffset)
-    assertEquals(UnifiedLog.UnknownOffset, remoteReplica.logStartOffset)
+    assertReplicaState(partition, remoteBrokerId,
+      lastCaughtUpTimeMs = initializeTimeMs,
+      logStartOffset = UnifiedLog.UnknownOffset,
+      logEndOffset = UnifiedLog.UnknownOffset
+    )
 
     // Shrink the ISR
     time.sleep(partition.replicaLagTimeMaxMs + 1)
@@ -1433,10 +1454,11 @@ class PartitionTest extends AbstractPartitionTest {
     ))
     assertEquals(0L, partition.localLogOrException.highWatermark)
 
-    val remoteReplica = partition.getReplica(remoteBrokerId).get
-    assertEquals(initializeTimeMs, remoteReplica.lastCaughtUpTimeMs)
-    assertEquals(LogOffsetMetadata.UnknownOffsetMetadata.messageOffset, 
remoteReplica.logEndOffset)
-    assertEquals(UnifiedLog.UnknownOffset, remoteReplica.logStartOffset)
+    assertReplicaState(partition, remoteBrokerId,
+      lastCaughtUpTimeMs = initializeTimeMs,
+      logStartOffset = UnifiedLog.UnknownOffset,
+      logEndOffset = UnifiedLog.UnknownOffset
+    )
 
     // There is a short delay before the first fetch. The follower is not yet 
caught up to the log end.
     time.sleep(5000)
@@ -1446,10 +1468,12 @@ class PartitionTest extends AbstractPartitionTest {
       followerStartOffset = 0L,
       followerFetchTimeMs = firstFetchTimeMs,
       leaderEndOffset = 10L)
-    assertEquals(initializeTimeMs, remoteReplica.lastCaughtUpTimeMs)
+    assertReplicaState(partition, remoteBrokerId,
+      lastCaughtUpTimeMs = initializeTimeMs,
+      logStartOffset = 0L,
+      logEndOffset = 5L
+    )
     assertEquals(5L, partition.localLogOrException.highWatermark)
-    assertEquals(5L, remoteReplica.logEndOffset)
-    assertEquals(0L, remoteReplica.logStartOffset)
 
     // Some new data is appended, but the follower catches up to the old end offset.
     // The total elapsed time from initialization is larger than the max allowed replica lag.
@@ -1460,10 +1484,12 @@ class PartitionTest extends AbstractPartitionTest {
       followerStartOffset = 0L,
       followerFetchTimeMs = time.milliseconds(),
       leaderEndOffset = 15L)
-    assertEquals(firstFetchTimeMs, remoteReplica.lastCaughtUpTimeMs)
+    assertReplicaState(partition, remoteBrokerId,
+      lastCaughtUpTimeMs = firstFetchTimeMs,
+      logStartOffset = 0L,
+      logEndOffset = 10L
+    )
     assertEquals(10L, partition.localLogOrException.highWatermark)
-    assertEquals(10L, remoteReplica.logEndOffset)
-    assertEquals(0L, remoteReplica.logStartOffset)
 
     // The ISR should not be shrunk because the follower has caught up with the leader at the
     // time of the first fetch.
@@ -1495,10 +1521,11 @@ class PartitionTest extends AbstractPartitionTest {
     ))
     assertEquals(0L, partition.localLogOrException.highWatermark)
 
-    val remoteReplica = partition.getReplica(remoteBrokerId).get
-    assertEquals(initializeTimeMs, remoteReplica.lastCaughtUpTimeMs)
-    assertEquals(LogOffsetMetadata.UnknownOffsetMetadata.messageOffset, remoteReplica.logEndOffset)
-    assertEquals(UnifiedLog.UnknownOffset, remoteReplica.logStartOffset)
+    assertReplicaState(partition, remoteBrokerId,
+      lastCaughtUpTimeMs = initializeTimeMs,
+      logStartOffset = UnifiedLog.UnknownOffset,
+      logEndOffset = UnifiedLog.UnknownOffset
+    )
 
     // The follower catches up to the log end immediately.
     partition.updateFollowerFetchState(remoteBrokerId,
@@ -1506,10 +1533,12 @@ class PartitionTest extends AbstractPartitionTest {
       followerStartOffset = 0L,
       followerFetchTimeMs = time.milliseconds(),
       leaderEndOffset = 10L)
-    assertEquals(initializeTimeMs, remoteReplica.lastCaughtUpTimeMs)
+    assertReplicaState(partition, remoteBrokerId,
+      lastCaughtUpTimeMs = time.milliseconds(),
+      logStartOffset = 0L,
+      logEndOffset = 10L
+    )
     assertEquals(10L, partition.localLogOrException.highWatermark)
-    assertEquals(10L, remoteReplica.logEndOffset)
-    assertEquals(0L, remoteReplica.logStartOffset)
 
     // Sleep longer than the max allowed follower lag
     time.sleep(30001)
@@ -1543,10 +1572,11 @@ class PartitionTest extends AbstractPartitionTest {
     ))
     assertEquals(0L, partition.localLogOrException.highWatermark)
 
-    val remoteReplica = partition.getReplica(remoteBrokerId).get
-    assertEquals(initializeTimeMs, remoteReplica.lastCaughtUpTimeMs)
-    assertEquals(LogOffsetMetadata.UnknownOffsetMetadata.messageOffset, remoteReplica.logEndOffset)
-    assertEquals(UnifiedLog.UnknownOffset, remoteReplica.logStartOffset)
+    assertReplicaState(partition, remoteBrokerId,
+      lastCaughtUpTimeMs = initializeTimeMs,
+      logStartOffset = UnifiedLog.UnknownOffset,
+      logEndOffset = UnifiedLog.UnknownOffset
+    )
 
     time.sleep(30001)
 
@@ -1618,11 +1648,14 @@ class PartitionTest extends AbstractPartitionTest {
     ))
     assertEquals(10L, partition.localLogOrException.highWatermark)
 
-    val remoteReplica = partition.getReplica(remoteBrokerId).get
-    assertEquals(LogOffsetMetadata.UnknownOffsetMetadata.messageOffset, remoteReplica.logEndOffset)
-    assertEquals(UnifiedLog.UnknownOffset, remoteReplica.logStartOffset)
+    assertReplicaState(partition, remoteBrokerId,
+      lastCaughtUpTimeMs = 0L,
+      logStartOffset = UnifiedLog.UnknownOffset,
+      logEndOffset = UnifiedLog.UnknownOffset
+    )
 
     // This will attempt to expand the ISR
+    val firstFetchTimeMs = time.milliseconds()
     partition.updateFollowerFetchState(remoteBrokerId,
       followerFetchOffsetMetadata = LogOffsetMetadata(10),
       followerStartOffset = 0L,
@@ -1633,8 +1666,11 @@ class PartitionTest extends AbstractPartitionTest {
     assertEquals(Set(brokerId), partition.inSyncReplicaIds)
     assertEquals(Set(brokerId, remoteBrokerId), partition.partitionState.maximalIsr)
     assertEquals(alterIsrManager.isrUpdates.size, 1)
-    assertEquals(10L, remoteReplica.logEndOffset)
-    assertEquals(0L, remoteReplica.logStartOffset)
+    assertReplicaState(partition, remoteBrokerId,
+      lastCaughtUpTimeMs = firstFetchTimeMs,
+      logStartOffset = 0L,
+      logEndOffset = 10L
+    )
 
     // Failure
     alterIsrManager.failIsrUpdate(error)
@@ -2137,4 +2173,26 @@ class PartitionTest extends AbstractPartitionTest {
       appendInfo
     }
   }
+
+  private def assertReplicaState(
+    partition: Partition,
+    replicaId: Int,
+    lastCaughtUpTimeMs: Long,
+    logEndOffset: Long,
+    logStartOffset: Long
+  ): Unit = {
+    partition.getReplica(replicaId) match {
+      case Some(replica) =>
+        val replicaState = replica.stateSnapshot
+        assertEquals(lastCaughtUpTimeMs, replicaState.lastCaughtUpTimeMs,
+          "Unexpected Last Caught Up Time")
+        assertEquals(logEndOffset, replicaState.logEndOffset,
+          "Unexpected Log End Offset")
+        assertEquals(logStartOffset, replicaState.logStartOffset,
+          "Unexpected Log Start Offset")
+
+      case None =>
+        fail(s"Replica $replicaId not found.")
+    }
+  }
 }
diff --git a/core/src/test/scala/unit/kafka/cluster/ReplicaTest.scala b/core/src/test/scala/unit/kafka/cluster/ReplicaTest.scala
index 201ec1dea5..76910642ae 100644
--- a/core/src/test/scala/unit/kafka/cluster/ReplicaTest.scala
+++ b/core/src/test/scala/unit/kafka/cluster/ReplicaTest.scala
@@ -1,4 +1,4 @@
-/*
+/**
  * Licensed to the Apache Software Foundation (ASF) under one or more
  * contributor license agreements.  See the NOTICE file distributed with
  * this work for additional information regarding copyright ownership.
@@ -6,7 +6,7 @@
  * (the "License"); you may not use this file except in compliance with
  * the License.  You may obtain a copy of the License at
  *
- *      http://www.apache.org/licenses/LICENSE-2.0
+ *    http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
@@ -16,117 +16,293 @@
  */
 package kafka.cluster
 
-import java.util.Properties
+import kafka.log.UnifiedLog
+import kafka.server.LogOffsetMetadata
+import kafka.utils.MockTime
+import org.apache.kafka.common.TopicPartition
+import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertTrue}
+import org.junit.jupiter.api.{BeforeEach, Test}
 
-import kafka.log.{ClientRecordDeletion, UnifiedLog, LogConfig, LogManager}
-import kafka.server.{BrokerTopicStats, LogDirFailureChannel}
-import kafka.utils.{MockTime, TestUtils}
-import org.apache.kafka.common.errors.OffsetOutOfRangeException
-import org.apache.kafka.common.utils.Utils
-import org.junit.jupiter.api.Assertions._
-import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
+object ReplicaTest {
+  val BrokerId: Int = 0
+  val Partition: TopicPartition = new TopicPartition("foo", 0)
+  val ReplicaLagTimeMaxMs: Long = 30000
+}
 
 class ReplicaTest {
+  import ReplicaTest._
 
-  val tmpDir = TestUtils.tempDir()
-  val logDir = TestUtils.randomPartitionLogDir(tmpDir)
   val time = new MockTime()
-  val brokerTopicStats = new BrokerTopicStats
-  var log: UnifiedLog = _
+  var replica: Replica = _
 
   @BeforeEach
   def setup(): Unit = {
-    val logProps = new Properties()
-    logProps.put(LogConfig.SegmentBytesProp, 512: java.lang.Integer)
-    logProps.put(LogConfig.SegmentIndexBytesProp, 1000: java.lang.Integer)
-    logProps.put(LogConfig.RetentionMsProp, 999: java.lang.Integer)
-    val config = LogConfig(logProps)
-    log = UnifiedLog(
-      dir = logDir,
-      config = config,
-      logStartOffset = 0L,
-      recoveryPoint = 0L,
-      scheduler = time.scheduler,
-      brokerTopicStats = brokerTopicStats,
-      time = time,
-      maxTransactionTimeoutMs = 5 * 60 * 1000,
-      maxProducerIdExpirationMs = 60 * 60 * 1000,
-      producerIdExpirationCheckIntervalMs = LogManager.ProducerIdExpirationCheckIntervalMs,
-      logDirFailureChannel = new LogDirFailureChannel(10),
-      topicId = None,
-      keepPartitionMetadataFile = true
+    replica = new Replica(BrokerId, Partition)
+  }
+
+  private def assertReplicaState(
+    logStartOffset: Long,
+    logEndOffset: Long,
+    lastCaughtUpTimeMs: Long,
+    lastFetchLeaderLogEndOffset: Long,
+    lastFetchTimeMs: Long
+  ): Unit = {
+    val replicaState = replica.stateSnapshot
+    assertEquals(logStartOffset, replicaState.logStartOffset,
+      "Unexpected Log Start Offset")
+    assertEquals(logEndOffset, replicaState.logEndOffset,
+      "Unexpected Log End Offset")
+    assertEquals(lastCaughtUpTimeMs, replicaState.lastCaughtUpTimeMs,
+      "Unexpected Last Caught Up Time")
+    assertEquals(lastFetchLeaderLogEndOffset, replicaState.lastFetchLeaderLogEndOffset,
+      "Unexpected Last Fetch Leader Log End Offset")
+    assertEquals(lastFetchTimeMs, replicaState.lastFetchTimeMs,
+      "Unexpected Last Fetch Time")
+  }
+
+  def assertReplicaStateDoesNotChange(
+    op: => Unit
+  ): Unit = {
+    val previousState = replica.stateSnapshot
+
+    op
+
+    assertReplicaState(
+      logStartOffset = previousState.logStartOffset,
+      logEndOffset = previousState.logEndOffset,
+      lastCaughtUpTimeMs = previousState.lastCaughtUpTimeMs,
+      lastFetchLeaderLogEndOffset = previousState.lastFetchLeaderLogEndOffset,
+      lastFetchTimeMs = previousState.lastFetchTimeMs
+    )
+  }
+
+  private def updateFetchState(
+    followerFetchOffset: Long,
+    followerStartOffset: Long,
+    leaderEndOffset: Long
+  ): Long = {
+    val currentTimeMs = time.milliseconds()
+    replica.updateFetchState(
+      followerFetchOffsetMetadata = LogOffsetMetadata(followerFetchOffset),
+      followerStartOffset = followerStartOffset,
+      followerFetchTimeMs = currentTimeMs,
+      leaderEndOffset = leaderEndOffset
+    )
+    currentTimeMs
+  }
+
+  private def resetReplicaState(
+    leaderEndOffset: Long,
+    isNewLeader: Boolean,
+    isFollowerInSync: Boolean
+  ): Long = {
+    val currentTimeMs = time.milliseconds()
+    replica.resetReplicaState(
+      currentTimeMs = currentTimeMs,
+      leaderEndOffset = leaderEndOffset,
+      isNewLeader = isNewLeader,
+      isFollowerInSync = isFollowerInSync
+    )
+    currentTimeMs
+  }
+
+  private def isCaughtUp(
+    leaderEndOffset: Long
+  ): Boolean = {
+    replica.stateSnapshot.isCaughtUp(
+      leaderEndOffset = leaderEndOffset,
+      currentTimeMs = time.milliseconds(),
+      replicaMaxLagMs = ReplicaLagTimeMaxMs
+    )
+  }
+
+  @Test
+  def testInitialState(): Unit = {
+    assertReplicaState(
+      logStartOffset = UnifiedLog.UnknownOffset,
+      logEndOffset = UnifiedLog.UnknownOffset,
+      lastCaughtUpTimeMs = 0L,
+      lastFetchLeaderLogEndOffset = 0L,
+      lastFetchTimeMs = 0L
+    )
+  }
+
+  @Test
+  def testUpdateFetchState(): Unit = {
+    val fetchTimeMs1 = updateFetchState(
+      followerFetchOffset = 5L,
+      followerStartOffset = 1L,
+      leaderEndOffset = 10L
+    )
+
+    assertReplicaState(
+      logStartOffset = 1L,
+      logEndOffset = 5L,
+      lastCaughtUpTimeMs = 0L,
+      lastFetchLeaderLogEndOffset = 10L,
+      lastFetchTimeMs = fetchTimeMs1
+    )
+
+    val fetchTimeMs2 = updateFetchState(
+      followerFetchOffset = 10L,
+      followerStartOffset = 2L,
+      leaderEndOffset = 15L
+    )
+
+    assertReplicaState(
+      logStartOffset = 2L,
+      logEndOffset = 10L,
+      lastCaughtUpTimeMs = fetchTimeMs1,
+      lastFetchLeaderLogEndOffset = 15L,
+      lastFetchTimeMs = fetchTimeMs2
+    )
+
+    val fetchTimeMs3 = updateFetchState(
+      followerFetchOffset = 15L,
+      followerStartOffset = 3L,
+      leaderEndOffset = 15L
+    )
+
+    assertReplicaState(
+      logStartOffset = 3L,
+      logEndOffset = 15L,
+      lastCaughtUpTimeMs = fetchTimeMs3,
+      lastFetchLeaderLogEndOffset = 15L,
+      lastFetchTimeMs = fetchTimeMs3
+    )
+  }
+
+  @Test
+  def testResetReplicaStateWhenLeaderIsReelectedAndReplicaIsInSync(): Unit = {
+    updateFetchState(
+      followerFetchOffset = 10L,
+      followerStartOffset = 1L,
+      leaderEndOffset = 10L
+    )
+
+    val resetTimeMs1 = resetReplicaState(
+      leaderEndOffset = 11L,
+      isNewLeader = false,
+      isFollowerInSync = true
+    )
+
+    assertReplicaState(
+      logStartOffset = 1L,
+      logEndOffset = 10L,
+      lastCaughtUpTimeMs = resetTimeMs1,
+      lastFetchLeaderLogEndOffset = 11L,
+      lastFetchTimeMs = resetTimeMs1
     )
   }
 
-  @AfterEach
-  def tearDown(): Unit = {
-    log.close()
-    brokerTopicStats.close()
-    Utils.delete(tmpDir)
+  @Test
+  def testResetReplicaStateWhenLeaderIsReelectedAndReplicaIsNotInSync(): Unit = {
+    updateFetchState(
+      followerFetchOffset = 10L,
+      followerStartOffset = 1L,
+      leaderEndOffset = 10L
+    )
+
+    resetReplicaState(
+      leaderEndOffset = 11L,
+      isNewLeader = false,
+      isFollowerInSync = false
+    )
+
+    assertReplicaState(
+      logStartOffset = 1L,
+      logEndOffset = 10L,
+      lastCaughtUpTimeMs = 0L,
+      lastFetchLeaderLogEndOffset = 11L,
+      lastFetchTimeMs = 0L
+    )
   }
 
   @Test
-  def testSegmentDeletionWithHighWatermarkInitialization(): Unit = {
-    val expiredTimestamp = time.milliseconds() - 1000
-    for (i <- 0 until 100) {
-      val records = TestUtils.singletonRecords(value = s"test$i".getBytes, timestamp = expiredTimestamp)
-      log.appendAsLeader(records, leaderEpoch = 0)
-    }
-
-    val initialHighWatermark = log.updateHighWatermark(25L)
-    assertEquals(25L, initialHighWatermark)
-
-    val initialNumSegments = log.numberOfSegments
-    log.deleteOldSegments()
-    assertTrue(log.numberOfSegments < initialNumSegments)
-    assertTrue(log.logStartOffset <= initialHighWatermark)
+  def testResetReplicaStateWhenNewLeaderIsElectedAndReplicaIsInSync(): Unit = {
+    updateFetchState(
+      followerFetchOffset = 10L,
+      followerStartOffset = 1L,
+      leaderEndOffset = 10L
+    )
+
+    val resetTimeMs1 = resetReplicaState(
+      leaderEndOffset = 11L,
+      isNewLeader = true,
+      isFollowerInSync = true
+    )
+
+    assertReplicaState(
+      logStartOffset = UnifiedLog.UnknownOffset,
+      logEndOffset = UnifiedLog.UnknownOffset,
+      lastCaughtUpTimeMs = resetTimeMs1,
+      lastFetchLeaderLogEndOffset = UnifiedLog.UnknownOffset,
+      lastFetchTimeMs = 0L
+    )
   }
 
   @Test
-  def testCannotDeleteSegmentsAtOrAboveHighWatermark(): Unit = {
-    val expiredTimestamp = time.milliseconds() - 1000
-    for (i <- 0 until 100) {
-      val records = TestUtils.singletonRecords(value = s"test$i".getBytes, timestamp = expiredTimestamp)
-      log.appendAsLeader(records, leaderEpoch = 0)
-    }
-
-    // ensure we have at least a few segments so the test case is not trivial
-    assertTrue(log.numberOfSegments > 5)
-    assertEquals(0L, log.highWatermark)
-    assertEquals(0L, log.logStartOffset)
-    assertEquals(100L, log.logEndOffset)
-
-    for (hw <- 0 to 100) {
-      log.updateHighWatermark(hw)
-      assertEquals(hw, log.highWatermark)
-      log.deleteOldSegments()
-      assertTrue(log.logStartOffset <= hw)
-
-      // verify that all segments up to the high watermark have been deleted
-
-      log.logSegments.headOption.foreach { segment =>
-        assertTrue(segment.baseOffset <= hw)
-        assertTrue(segment.baseOffset >= log.logStartOffset)
-      }
-      log.logSegments.tail.foreach { segment =>
-        assertTrue(segment.baseOffset > hw)
-        assertTrue(segment.baseOffset >= log.logStartOffset)
-      }
-    }
-
-    assertEquals(100L, log.logStartOffset)
-    assertEquals(1, log.numberOfSegments)
-    assertEquals(0, log.activeSegment.size)
+  def testResetReplicaStateWhenNewLeaderIsElectedAndReplicaIsNotInSync(): Unit = {
+    updateFetchState(
+      followerFetchOffset = 10L,
+      followerStartOffset = 1L,
+      leaderEndOffset = 10L
+    )
+
+    resetReplicaState(
+      leaderEndOffset = 11L,
+      isNewLeader = true,
+      isFollowerInSync = false
+    )
+
+    assertReplicaState(
+      logStartOffset = UnifiedLog.UnknownOffset,
+      logEndOffset = UnifiedLog.UnknownOffset,
+      lastCaughtUpTimeMs = 0L,
+      lastFetchLeaderLogEndOffset = UnifiedLog.UnknownOffset,
+      lastFetchTimeMs = 0L
+    )
+  }
+
+  @Test
+  def testIsCaughtUpWhenReplicaIsCaughtUpToLogEnd(): Unit = {
+    assertFalse(isCaughtUp(leaderEndOffset = 10L))
+
+    updateFetchState(
+      followerFetchOffset = 10L,
+      followerStartOffset = 1L,
+      leaderEndOffset = 10L
+    )
+
+    assertTrue(isCaughtUp(leaderEndOffset = 10L))
+
+    time.sleep(ReplicaLagTimeMaxMs + 1)
+
+    assertTrue(isCaughtUp(leaderEndOffset = 10L))
   }
 
   @Test
-  def testCannotIncrementLogStartOffsetPastHighWatermark(): Unit = {
-    for (i <- 0 until 100) {
-      val records = TestUtils.singletonRecords(value = s"test$i".getBytes)
-      log.appendAsLeader(records, leaderEpoch = 0)
-    }
-
-    log.updateHighWatermark(25L)
-    assertThrows(classOf[OffsetOutOfRangeException], () => log.maybeIncrementLogStartOffset(26L, ClientRecordDeletion))
+  def testIsCaughtUpWhenReplicaIsNotCaughtUpToLogEnd(): Unit = {
+    assertFalse(isCaughtUp(leaderEndOffset = 10L))
+
+    updateFetchState(
+      followerFetchOffset = 5L,
+      followerStartOffset = 1L,
+      leaderEndOffset = 10L
+    )
+
+    assertFalse(isCaughtUp(leaderEndOffset = 10L))
+
+    updateFetchState(
+      followerFetchOffset = 10L,
+      followerStartOffset = 1L,
+      leaderEndOffset = 15L
+    )
+
+    assertTrue(isCaughtUp(leaderEndOffset = 16L))
+
+    time.sleep(ReplicaLagTimeMaxMs + 1)
+
+    assertFalse(isCaughtUp(leaderEndOffset = 16L))
   }
 }
diff --git a/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala b/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala
index e8c2c28e76..e2da713809 100755
--- a/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala
+++ b/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala
@@ -3325,6 +3325,91 @@ class UnifiedLogTest {
     assertEquals(1, log.numberOfSegments)
   }
 
+  @Test
+  def testSegmentDeletionWithHighWatermarkInitialization(): Unit = {
+    val logConfig = LogTestUtils.createLogConfig(
+      segmentBytes = 512,
+      segmentIndexBytes = 1000,
+      retentionMs = 999
+    )
+    val log = createLog(logDir, logConfig)
+
+    val expiredTimestamp = mockTime.milliseconds() - 1000
+    for (i <- 0 until 100) {
+      val records = TestUtils.singletonRecords(value = s"test$i".getBytes, timestamp = expiredTimestamp)
+      log.appendAsLeader(records, leaderEpoch = 0)
+    }
+
+    val initialHighWatermark = log.updateHighWatermark(25L)
+    assertEquals(25L, initialHighWatermark)
+
+    val initialNumSegments = log.numberOfSegments
+    log.deleteOldSegments()
+    assertTrue(log.numberOfSegments < initialNumSegments)
+    assertTrue(log.logStartOffset <= initialHighWatermark)
+  }
+
+  @Test
+  def testCannotDeleteSegmentsAtOrAboveHighWatermark(): Unit = {
+    val logConfig = LogTestUtils.createLogConfig(
+      segmentBytes = 512,
+      segmentIndexBytes = 1000,
+      retentionMs = 999
+    )
+    val log = createLog(logDir, logConfig)
+
+    val expiredTimestamp = mockTime.milliseconds() - 1000
+    for (i <- 0 until 100) {
+      val records = TestUtils.singletonRecords(value = s"test$i".getBytes, timestamp = expiredTimestamp)
+      log.appendAsLeader(records, leaderEpoch = 0)
+    }
+
+    // ensure we have at least a few segments so the test case is not trivial
+    assertTrue(log.numberOfSegments > 5)
+    assertEquals(0L, log.highWatermark)
+    assertEquals(0L, log.logStartOffset)
+    assertEquals(100L, log.logEndOffset)
+
+    for (hw <- 0 to 100) {
+      log.updateHighWatermark(hw)
+      assertEquals(hw, log.highWatermark)
+      log.deleteOldSegments()
+      assertTrue(log.logStartOffset <= hw)
+
+      // verify that all segments up to the high watermark have been deleted
+      log.logSegments.headOption.foreach { segment =>
+        assertTrue(segment.baseOffset <= hw)
+        assertTrue(segment.baseOffset >= log.logStartOffset)
+      }
+      log.logSegments.tail.foreach { segment =>
+        assertTrue(segment.baseOffset > hw)
+        assertTrue(segment.baseOffset >= log.logStartOffset)
+      }
+    }
+
+    assertEquals(100L, log.logStartOffset)
+    assertEquals(1, log.numberOfSegments)
+    assertEquals(0, log.activeSegment.size)
+  }
+
+  @Test
+  def testCannotIncrementLogStartOffsetPastHighWatermark(): Unit = {
+    val logConfig = LogTestUtils.createLogConfig(
+      segmentBytes = 512,
+      segmentIndexBytes = 1000,
+      retentionMs = 999
+    )
+    val log = createLog(logDir, logConfig)
+
+    for (i <- 0 until 100) {
+      val records = TestUtils.singletonRecords(value = s"test$i".getBytes)
+      log.appendAsLeader(records, leaderEpoch = 0)
+    }
+
+    log.updateHighWatermark(25L)
+    assertThrows(classOf[OffsetOutOfRangeException], () => log.maybeIncrementLogStartOffset(26L, ClientRecordDeletion))
+  }
+
   private def appendTransactionalToBuffer(buffer: ByteBuffer,
                                           producerId: Long,
                                           producerEpoch: Short,
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
index 96299bf2fd..6d01f59259 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
@@ -740,8 +740,8 @@ class ReplicaManagerTest {
 
       assertTrue(partition.getReplica(1).isDefined)
       val followerReplica = partition.getReplica(1).get
-      assertEquals(-1L, followerReplica.logStartOffset)
-      assertEquals(-1L, followerReplica.logEndOffset)
+      assertEquals(-1L, followerReplica.stateSnapshot.logStartOffset)
+      assertEquals(-1L, followerReplica.stateSnapshot.logEndOffset)
 
       // Leader appends some data
       for (i <- 1 to 5) {
@@ -773,8 +773,8 @@ class ReplicaManagerTest {
       )
 
       assertTrue(successfulFetch.isDefined)
-      assertEquals(0L, followerReplica.logStartOffset)
-      assertEquals(0L, followerReplica.logEndOffset)
+      assertEquals(0L, followerReplica.stateSnapshot.logStartOffset)
+      assertEquals(0L, followerReplica.stateSnapshot.logEndOffset)
 
 
       // Next we receive an invalid request with a higher fetch offset, but an old epoch.
@@ -796,8 +796,8 @@ class ReplicaManagerTest {
       )
 
       assertTrue(successfulFetch.isDefined)
-      assertEquals(0L, followerReplica.logStartOffset)
-      assertEquals(0L, followerReplica.logEndOffset)
+      assertEquals(0L, followerReplica.stateSnapshot.logStartOffset)
+      assertEquals(0L, followerReplica.stateSnapshot.logEndOffset)
 
       // Next we receive an invalid request with a higher fetch offset, but a diverging epoch.
       // We expect that the replica state does not get updated.
@@ -818,8 +818,8 @@ class ReplicaManagerTest {
       )
 
       assertTrue(successfulFetch.isDefined)
-      assertEquals(0L, followerReplica.logStartOffset)
-      assertEquals(0L, followerReplica.logEndOffset)
+      assertEquals(0L, followerReplica.stateSnapshot.logStartOffset)
+      assertEquals(0L, followerReplica.stateSnapshot.logEndOffset)
 
     } finally {
       replicaManager.shutdown(checkpointHW = false)

Reply via email to