This is an automated email from the ASF dual-hosted git repository.

showuon pushed a commit to branch 3.3
in repository https://gitbox.apache.org/repos/asf/kafka.git

commit 158fc5f0269670d80d6c8f49186c51930bd600a1
Author: Luke Chen <show...@gmail.com>
AuthorDate: Fri Jul 22 11:00:15 2022 +0800

    KAFKA-13919: expose log recovery metrics (#12347)
    
    Implementation for KIP-831.
    1. Add a remainingLogsToRecover metric for the number of remaining logs to be recovered in each log.dir.
    2. Add a remainingSegmentsToRecover metric for the number of remaining segments of the current log assigned to each recovery thread.
    3. Remove these metrics after all logs are loaded completely.
    4. Add tests.
    
    Reviewers: Jun Rao <j...@confluent.io>, Tom Bentley <tbent...@redhat.com>
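
For context, a minimal sketch of how the two new gauges could be polled over JMX
while a broker is still recovering. The kafka.log:type=LogManager object-name
pattern, the dir/threadNum tags, and the "Value" attribute are assumptions inferred
from how the gauges are registered in LogManager below, not something this patch
itself documents; verify them against your broker before relying on them.

    import java.lang.management.ManagementFactory
    import javax.management.ObjectName
    import scala.jdk.CollectionConverters._

    // Hypothetical probe: run it inside the broker JVM, or adapt it with a
    // JMXConnector for remote polling.
    object LogRecoveryMetricsProbe {
      def main(args: Array[String]): Unit = {
        val mbs = ManagementFactory.getPlatformMBeanServer
        val patterns = Seq(
          new ObjectName("kafka.log:type=LogManager,name=remainingLogsToRecover,*"),
          new ObjectName("kafka.log:type=LogManager,name=remainingSegmentsToRecover,*")
        )
        patterns.foreach { pattern =>
          // each matching MBean is one gauge: per log.dir for logs,
          // per recovery thread for segments
          mbs.queryNames(pattern, null).asScala.foreach { name =>
            println(s"$name -> ${mbs.getAttribute(name, "Value")}")
          }
        }
      }
    }

Since the gauges are removed once log loading completes, an empty query result
after startup has finished is expected.
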
---
 core/src/main/scala/kafka/log/LogLoader.scala      |  26 ++-
 core/src/main/scala/kafka/log/LogManager.scala     |  85 ++++++--
 core/src/main/scala/kafka/log/UnifiedLog.scala     |  10 +-
 .../test/scala/unit/kafka/log/LogLoaderTest.scala  |   3 +-
 .../test/scala/unit/kafka/log/LogManagerTest.scala | 220 ++++++++++++++++++++-
 .../test/scala/unit/kafka/log/LogTestUtils.scala   |   7 +-
 .../test/scala/unit/kafka/utils/TestUtils.scala    |   5 +-
 .../apache/kafka/jmh/server/CheckpointBench.java   |   2 +-
 8 files changed, 319 insertions(+), 39 deletions(-)

diff --git a/core/src/main/scala/kafka/log/LogLoader.scala b/core/src/main/scala/kafka/log/LogLoader.scala
index 581d016e5e..25ee89c72b 100644
--- a/core/src/main/scala/kafka/log/LogLoader.scala
+++ b/core/src/main/scala/kafka/log/LogLoader.scala
@@ -29,6 +29,7 @@ import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.errors.InvalidOffsetException
 import org.apache.kafka.common.utils.Time
 
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
 import scala.collection.{Set, mutable}
 
 case class LoadedLogOffsets(logStartOffset: Long,
@@ -64,6 +65,7 @@ object LogLoader extends Logging {
  * @param recoveryPointCheckpoint The checkpoint of the offset at which to begin the recovery
  * @param leaderEpochCache An optional LeaderEpochFileCache instance to be updated during recovery
  * @param producerStateManager The ProducerStateManager instance to be updated during recovery
+ * @param numRemainingSegments The remaining segments to be recovered in this log keyed by recovery thread name
  */
 class LogLoader(
   dir: File,
@@ -77,7 +79,8 @@ class LogLoader(
   logStartOffsetCheckpoint: Long,
   recoveryPointCheckpoint: Long,
   leaderEpochCache: Option[LeaderEpochFileCache],
-  producerStateManager: ProducerStateManager
+  producerStateManager: ProducerStateManager,
+  numRemainingSegments: ConcurrentMap[String, Int] = new ConcurrentHashMap[String, Int]
 ) extends Logging {
   logIdent = s"[LogLoader partition=$topicPartition, dir=${dir.getParent}] "
 
@@ -404,12 +407,18 @@ class LogLoader(
 
     // If we have the clean shutdown marker, skip recovery.
     if (!hadCleanShutdown) {
-      val unflushed = segments.values(recoveryPointCheckpoint, Long.MaxValue).iterator
+      val unflushed = segments.values(recoveryPointCheckpoint, Long.MaxValue)
+      val numUnflushed = unflushed.size
+      val unflushedIter = unflushed.iterator
       var truncated = false
+      var numFlushed = 0
+      val threadName = Thread.currentThread().getName
+      numRemainingSegments.put(threadName, numUnflushed)
+
+      while (unflushedIter.hasNext && !truncated) {
+        val segment = unflushedIter.next()
+        info(s"Recovering unflushed segment ${segment.baseOffset}. 
$numFlushed/$numUnflushed recovered for $topicPartition.")
 
-      while (unflushed.hasNext && !truncated) {
-        val segment = unflushed.next()
-        info(s"Recovering unflushed segment ${segment.baseOffset}")
         val truncatedBytes =
           try {
             recoverSegment(segment)
@@ -424,8 +433,13 @@ class LogLoader(
           // we had an invalid message, delete all remaining log
           warn(s"Corruption found in segment ${segment.baseOffset}," +
             s" truncating to offset ${segment.readNextOffset}")
-          removeAndDeleteSegmentsAsync(unflushed.toList)
+          removeAndDeleteSegmentsAsync(unflushedIter.toList)
           truncated = true
+          // segment is truncated, so set remaining segments to 0
+          numRemainingSegments.put(threadName, 0)
+        } else {
+          numFlushed += 1
+          numRemainingSegments.put(threadName, numUnflushed - numFlushed)
         }
       }
     }
diff --git a/core/src/main/scala/kafka/log/LogManager.scala b/core/src/main/scala/kafka/log/LogManager.scala
index bdc7ffd74d..886f56c63c 100755
--- a/core/src/main/scala/kafka/log/LogManager.scala
+++ b/core/src/main/scala/kafka/log/LogManager.scala
@@ -262,7 +262,8 @@ class LogManager(logDirs: Seq[File],
                            recoveryPoints: Map[TopicPartition, Long],
                            logStartOffsets: Map[TopicPartition, Long],
                            defaultConfig: LogConfig,
-                           topicConfigOverrides: Map[String, LogConfig]): UnifiedLog = {
+                           topicConfigOverrides: Map[String, LogConfig],
+                           numRemainingSegments: ConcurrentMap[String, Int]): UnifiedLog = {
     val topicPartition = UnifiedLog.parseTopicPartitionName(logDir)
     val config = topicConfigOverrides.getOrElse(topicPartition.topic, defaultConfig)
     val logRecoveryPoint = recoveryPoints.getOrElse(topicPartition, 0L)
@@ -282,7 +283,8 @@ class LogManager(logDirs: Seq[File],
       logDirFailureChannel = logDirFailureChannel,
       lastShutdownClean = hadCleanShutdown,
       topicId = None,
-      keepPartitionMetadataFile = keepPartitionMetadataFile)
+      keepPartitionMetadataFile = keepPartitionMetadataFile,
+      numRemainingSegments = numRemainingSegments)
 
     if (logDir.getName.endsWith(UnifiedLog.DeleteDirSuffix)) {
       addLogToBeDeleted(log)
@@ -307,6 +309,27 @@ class LogManager(logDirs: Seq[File],
     log
   }
 
+  // factory class for naming the log recovery threads used in metrics
+  class LogRecoveryThreadFactory(val dirPath: String) extends ThreadFactory {
+    val threadNum = new AtomicInteger(0)
+
+    override def newThread(runnable: Runnable): Thread = {
+      KafkaThread.nonDaemon(logRecoveryThreadName(dirPath, threadNum.getAndIncrement()), runnable)
+    }
+  }
+
+  // create a unique log recovery thread name for each log dir, in the format prefix-dirPath-threadNum, e.g. "log-recovery-/tmp/kafkaLogs-0"
+  private def logRecoveryThreadName(dirPath: String, threadNum: Int, prefix: String = "log-recovery"): String = s"$prefix-$dirPath-$threadNum"
+
+  /*
+   * decrement the number of remaining logs
+   * @return the number of remaining logs after decrementing by 1
+   */
+  private[log] def decNumRemainingLogs(numRemainingLogs: ConcurrentMap[String, Int], path: String): Int = {
+    require(path != null, "path cannot be null to update remaining logs metric.")
+    numRemainingLogs.compute(path, (_, oldVal) => oldVal - 1)
+  }
+
   /**
    * Recover and load all logs in the given data directories
    */
@@ -317,6 +340,10 @@ class LogManager(logDirs: Seq[File],
     val offlineDirs = mutable.Set.empty[(String, IOException)]
     val jobs = ArrayBuffer.empty[Seq[Future[_]]]
     var numTotalLogs = 0
+    // log dir path -> number of remaining logs map, for the remainingLogsToRecover metric
+    val numRemainingLogs: ConcurrentMap[String, Int] = new ConcurrentHashMap[String, Int]
+    // log recovery thread name -> number of remaining segments map, for the remainingSegmentsToRecover metric
+    val numRemainingSegments: ConcurrentMap[String, Int] = new ConcurrentHashMap[String, Int]
 
     def handleIOException(logDirAbsolutePath: String, e: IOException): Unit = {
       offlineDirs.add((logDirAbsolutePath, e))
@@ -328,7 +355,7 @@ class LogManager(logDirs: Seq[File],
       var hadCleanShutdown: Boolean = false
       try {
         val pool = Executors.newFixedThreadPool(numRecoveryThreadsPerDataDir,
-          KafkaThread.nonDaemon(s"log-recovery-$logDirAbsolutePath", _))
+          new LogRecoveryThreadFactory(logDirAbsolutePath))
         threadPools.append(pool)
 
         val cleanShutdownFile = new File(dir, LogLoader.CleanShutdownFile)
@@ -363,28 +390,32 @@ class LogManager(logDirs: Seq[File],
 
         val logsToLoad = Option(dir.listFiles).getOrElse(Array.empty).filter(logDir =>
           logDir.isDirectory && UnifiedLog.parseTopicPartitionName(logDir).topic != KafkaRaftServer.MetadataTopic)
-        val numLogsLoaded = new AtomicInteger(0)
         numTotalLogs += logsToLoad.length
+        numRemainingLogs.put(dir.getAbsolutePath, logsToLoad.length)
 
         val jobsForDir = logsToLoad.map { logDir =>
           val runnable: Runnable = () => {
+            debug(s"Loading log $logDir")
+            var log = None: Option[UnifiedLog]
+            val logLoadStartMs = time.hiResClockMs()
             try {
-              debug(s"Loading log $logDir")
-
-              val logLoadStartMs = time.hiResClockMs()
-              val log = loadLog(logDir, hadCleanShutdown, recoveryPoints, logStartOffsets,
-                defaultConfig, topicConfigOverrides)
-              val logLoadDurationMs = time.hiResClockMs() - logLoadStartMs
-              val currentNumLoaded = numLogsLoaded.incrementAndGet()
-
-              info(s"Completed load of $log with ${log.numberOfSegments} segments in ${logLoadDurationMs}ms " +
-                s"($currentNumLoaded/${logsToLoad.length} loaded in $logDirAbsolutePath)")
+              log = Some(loadLog(logDir, hadCleanShutdown, recoveryPoints, logStartOffsets,
+                defaultConfig, topicConfigOverrides, numRemainingSegments))
             } catch {
               case e: IOException =>
                 handleIOException(logDirAbsolutePath, e)
              case e: KafkaStorageException if e.getCause.isInstanceOf[IOException] =>
                // KafkaStorageException might be thrown, ex: during writing LeaderEpochFileCache
                // And while converting IOException to KafkaStorageException, we've already handled the exception. So we can ignore it here.
+            } finally {
+              val logLoadDurationMs = time.hiResClockMs() - logLoadStartMs
+              val remainingLogs = decNumRemainingLogs(numRemainingLogs, dir.getAbsolutePath)
+              val currentNumLoaded = logsToLoad.length - remainingLogs
+              log match {
+                case Some(loadedLog) => info(s"Completed load of $loadedLog with ${loadedLog.numberOfSegments} segments in ${logLoadDurationMs}ms " +
+                  s"($currentNumLoaded/${logsToLoad.length} completed in $logDirAbsolutePath)")
+                case None => info(s"Error while loading logs in $logDir in ${logLoadDurationMs}ms ($currentNumLoaded/${logsToLoad.length} completed in $logDirAbsolutePath)")
+              }
             }
           }
           runnable
@@ -398,6 +429,7 @@ class LogManager(logDirs: Seq[File],
     }
 
     try {
+      addLogRecoveryMetrics(numRemainingLogs, numRemainingSegments)
       for (dirJobs <- jobs) {
         dirJobs.foreach(_.get)
       }
@@ -410,12 +442,37 @@ class LogManager(logDirs: Seq[File],
         error(s"There was an error in one of the threads during logs loading: 
${e.getCause}")
         throw e.getCause
     } finally {
+      removeLogRecoveryMetrics()
       threadPools.foreach(_.shutdown())
     }
 
     info(s"Loaded $numTotalLogs logs in ${time.hiResClockMs() - startMs}ms.")
   }
 
+  private[log] def addLogRecoveryMetrics(numRemainingLogs: ConcurrentMap[String, Int],
+                                         numRemainingSegments: ConcurrentMap[String, Int]): Unit = {
+    debug("Adding log recovery metrics")
+    for (dir <- logDirs) {
+      newGauge("remainingLogsToRecover", () => 
numRemainingLogs.get(dir.getAbsolutePath),
+        Map("dir" -> dir.getAbsolutePath))
+      for (i <- 0 until numRecoveryThreadsPerDataDir) {
+        val threadName = logRecoveryThreadName(dir.getAbsolutePath, i)
+        newGauge("remainingSegmentsToRecover", () => 
numRemainingSegments.get(threadName),
+          Map("dir" -> dir.getAbsolutePath, "threadNum" -> i.toString))
+      }
+    }
+  }
+
+  private[log] def removeLogRecoveryMetrics(): Unit = {
+    debug("Removing log recovery metrics")
+    for (dir <- logDirs) {
+      removeMetric("remainingLogsToRecover", Map("dir" -> dir.getAbsolutePath))
+      for (i <- 0 until numRecoveryThreadsPerDataDir) {
+        removeMetric("remainingSegmentsToRecover", Map("dir" -> 
dir.getAbsolutePath, "threadNum" -> i.toString))
+      }
+    }
+  }
+
   /**
    *  Start the background threads to flush logs and do log cleanup
    */
diff --git a/core/src/main/scala/kafka/log/UnifiedLog.scala b/core/src/main/scala/kafka/log/UnifiedLog.scala
index ddd66eb160..c4a2300237 100644
--- a/core/src/main/scala/kafka/log/UnifiedLog.scala
+++ b/core/src/main/scala/kafka/log/UnifiedLog.scala
@@ -18,11 +18,11 @@
 package kafka.log
 
 import com.yammer.metrics.core.MetricName
+
 import java.io.{File, IOException}
 import java.nio.file.Files
 import java.util.Optional
-import java.util.concurrent.TimeUnit
-
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, TimeUnit}
 import kafka.common.{LongRef, OffsetsOutOfOrderException, UnexpectedAppendOffsetException}
 import kafka.log.AppendOrigin.RaftLeader
 import kafka.message.{BrokerCompressionCodec, CompressionCodec, NoCompressionCodec}
@@ -1803,7 +1803,8 @@ object UnifiedLog extends Logging {
             logDirFailureChannel: LogDirFailureChannel,
             lastShutdownClean: Boolean = true,
             topicId: Option[Uuid],
-            keepPartitionMetadataFile: Boolean): UnifiedLog = {
+            keepPartitionMetadataFile: Boolean,
+            numRemainingSegments: ConcurrentMap[String, Int] = new ConcurrentHashMap[String, Int]): UnifiedLog = {
     // create the log directory if it doesn't exist
     Files.createDirectories(dir.toPath)
     val topicPartition = UnifiedLog.parseTopicPartitionName(dir)
@@ -1828,7 +1829,8 @@ object UnifiedLog extends Logging {
       logStartOffset,
       recoveryPoint,
       leaderEpochCache,
-      producerStateManager
+      producerStateManager,
+      numRemainingSegments
     ).load()
     val localLog = new LocalLog(dir, config, segments, offsets.recoveryPoint,
       offsets.nextOffsetMetadata, scheduler, time, topicPartition, logDirFailureChannel)
diff --git a/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala b/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala
index 0d41a5073b..c6379ff3f3 100644
--- a/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala
@@ -38,6 +38,7 @@ import org.mockito.ArgumentMatchers
 import org.mockito.ArgumentMatchers.{any, anyLong}
 import org.mockito.Mockito.{mock, reset, times, verify, when}
 
+import java.util.concurrent.ConcurrentMap
 import scala.annotation.nowarn
 import scala.collection.mutable.ListBuffer
 import scala.collection.{Iterable, Map, mutable}
@@ -117,7 +118,7 @@ class LogLoaderTest {
 
         override def loadLog(logDir: File, hadCleanShutdown: Boolean, recoveryPoints: Map[TopicPartition, Long],
                              logStartOffsets: Map[TopicPartition, Long], defaultConfig: LogConfig,
-                             topicConfigs: Map[String, LogConfig]): UnifiedLog = {
+                             topicConfigs: Map[String, LogConfig], numRemainingSegments: ConcurrentMap[String, Int]): UnifiedLog = {
           if (simulateError.hasError) {
             simulateError.errorType match {
               case ErrorTypes.KafkaStorageExceptionWithIOExceptionCause =>
diff --git a/core/src/test/scala/unit/kafka/log/LogManagerTest.scala b/core/src/test/scala/unit/kafka/log/LogManagerTest.scala
index 5353df6db3..1b2dd7809f 100755
--- a/core/src/test/scala/unit/kafka/log/LogManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogManagerTest.scala
@@ -17,10 +17,10 @@
 
 package kafka.log
 
-import com.yammer.metrics.core.MetricName
+import com.yammer.metrics.core.{Gauge, MetricName}
 import kafka.server.checkpoints.OffsetCheckpointFile
 import kafka.server.metadata.{ConfigRepository, MockConfigRepository}
-import kafka.server.{FetchDataInfo, FetchLogEnd}
+import kafka.server.{BrokerTopicStats, FetchDataInfo, FetchLogEnd, LogDirFailureChannel}
 import kafka.utils._
 import org.apache.directory.api.util.FileUtils
 import org.apache.kafka.common.errors.OffsetOutOfRangeException
@@ -29,16 +29,17 @@ import org.apache.kafka.common.{KafkaException, TopicPartition}
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
 import org.mockito.ArgumentMatchers.any
-import org.mockito.{ArgumentMatchers, Mockito}
-import org.mockito.Mockito.{doAnswer, mock, never, spy, times, verify}
+import org.mockito.{ArgumentCaptor, ArgumentMatchers, Mockito}
+import org.mockito.Mockito.{doAnswer, doNothing, mock, never, spy, times, verify}
+
 import java.io._
 import java.nio.file.Files
-import java.util.concurrent.Future
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Future}
 import java.util.{Collections, Properties}
-
 import org.apache.kafka.server.metrics.KafkaYammerMetrics
 
-import scala.collection.mutable
+import scala.collection.{Map, mutable}
+import scala.collection.mutable.ArrayBuffer
 import scala.jdk.CollectionConverters._
 import scala.util.{Failure, Try}
 
@@ -421,12 +422,14 @@ class LogManagerTest {
   }
 
   private def createLogManager(logDirs: Seq[File] = Seq(this.logDir),
-                               configRepository: ConfigRepository = new MockConfigRepository): LogManager = {
+                               configRepository: ConfigRepository = new MockConfigRepository,
+                               recoveryThreadsPerDataDir: Int = 1): LogManager = {
     TestUtils.createLogManager(
       defaultConfig = logConfig,
       configRepository = configRepository,
       logDirs = logDirs,
-      time = this.time)
+      time = this.time,
+      recoveryThreadsPerDataDir = recoveryThreadsPerDataDir)
   }
 
   @Test
@@ -638,6 +641,205 @@ class LogManagerTest {
     assertTrue(logManager.partitionsInitializing.isEmpty)
   }
 
+  private def appendRecordsToLog(time: MockTime, parentLogDir: File, partitionId: Int, brokerTopicStats: BrokerTopicStats, expectedSegmentsPerLog: Int): Unit = {
+    def createRecord = TestUtils.singletonRecords(value = "test".getBytes, timestamp = time.milliseconds)
+    val tpFile = new File(parentLogDir, s"$name-$partitionId")
+    val segmentBytes = 1024
+
+    val log = LogTestUtils.createLog(tpFile, logConfig, brokerTopicStats, time.scheduler, time, 0, 0,
+      5 * 60 * 1000, 60 * 60 * 1000, LogManager.ProducerIdExpirationCheckIntervalMs)
+
+    assertTrue(expectedSegmentsPerLog > 0)
+    // calculate numMessages to append to logs. It'll create "expectedSegmentsPerLog" log segments with segment.bytes=1024
+    val numMessages = Math.floor(segmentBytes * expectedSegmentsPerLog / createRecord.sizeInBytes).asInstanceOf[Int]
+    try {
+      for (_ <- 0 until numMessages) {
+        log.appendAsLeader(createRecord, leaderEpoch = 0)
+      }
+
+      assertEquals(expectedSegmentsPerLog, log.numberOfSegments)
+    } finally {
+      log.close()
+    }
+  }
+
+  private def verifyRemainingLogsToRecoverMetric(spyLogManager: LogManager, expectedParams: Map[String, Int]): Unit = {
+    val spyLogManagerClassName = spyLogManager.getClass().getSimpleName
+    // get all `remainingLogsToRecover` metrics
+    val logMetrics: ArrayBuffer[Gauge[Int]] = KafkaYammerMetrics.defaultRegistry.allMetrics.asScala
+      .filter { case (metric, _) => metric.getType == s"$spyLogManagerClassName" && metric.getName == "remainingLogsToRecover" }
+      .map { case (_, gauge) => gauge }
+      .asInstanceOf[ArrayBuffer[Gauge[Int]]]
+
+    assertEquals(expectedParams.size, logMetrics.size)
+
+    val capturedPath: ArgumentCaptor[String] = ArgumentCaptor.forClass(classOf[String])
+
+    val expectedCallTimes = expectedParams.values.sum
+    verify(spyLogManager, times(expectedCallTimes)).decNumRemainingLogs(any[ConcurrentMap[String, Int]], capturedPath.capture());
+
+    val paths = capturedPath.getAllValues
+    expectedParams.foreach {
+      case (path, totalLogs) =>
+        // make sure each path is called "totalLogs" times, which means it is decremented to 0 in the end
+        assertEquals(totalLogs, Collections.frequency(paths, path))
+    }
+
+    // expected the end value is 0
+    logMetrics.foreach { gauge => assertEquals(0, gauge.value()) }
+  }
+
+  private def verifyRemainingSegmentsToRecoverMetric(spyLogManager: LogManager,
+                                                     logDirs: Seq[File],
+                                                     recoveryThreadsPerDataDir: Int,
+                                                     mockMap: ConcurrentHashMap[String, Int],
+                                                     expectedParams: Map[String, Int]): Unit = {
+    val spyLogManagerClassName = spyLogManager.getClass().getSimpleName
+    // get all `remainingSegmentsToRecover` metrics
+    val logSegmentMetrics: ArrayBuffer[Gauge[Int]] = KafkaYammerMetrics.defaultRegistry.allMetrics.asScala
+          .filter { case (metric, _) => metric.getType == s"$spyLogManagerClassName" && metric.getName == "remainingSegmentsToRecover" }
+          .map { case (_, gauge) => gauge }
+          .asInstanceOf[ArrayBuffer[Gauge[Int]]]
+
+    // each log dir is expected to expose 1 metric per recovery thread
+    assertEquals(recoveryThreadsPerDataDir * logDirs.size, logSegmentMetrics.size)
+
+    val capturedThreadName: ArgumentCaptor[String] = ArgumentCaptor.forClass(classOf[String])
+    val capturedNumRemainingSegments: ArgumentCaptor[Int] = ArgumentCaptor.forClass(classOf[Int])
+
+    // Since numRemainingSegments is updated from totalSegments down to 0 for each thread, we need to add 1 here
+    val expectedCallTimes = expectedParams.values.map( num => num + 1 ).sum
+    verify(mockMap, times(expectedCallTimes)).put(capturedThreadName.capture(), capturedNumRemainingSegments.capture());
+
+    // expected the end value is 0
+    logSegmentMetrics.foreach { gauge => assertEquals(0, gauge.value()) }
+
+    val threadNames = capturedThreadName.getAllValues
+    val numRemainingSegments = capturedNumRemainingSegments.getAllValues
+
+    expectedParams.foreach {
+      case (threadName, totalSegments) =>
+        // make sure we update the numRemainingSegments from totalSegments to 0 in order for each thread
+        var expectedCurRemainingSegments = totalSegments + 1
+        for (i <- 0 until threadNames.size) {
+          if (threadNames.get(i).contains(threadName)) {
+            expectedCurRemainingSegments -= 1
+            assertEquals(expectedCurRemainingSegments, numRemainingSegments.get(i))
+          }
+        }
+        assertEquals(0, expectedCurRemainingSegments)
+    }
+  }
+
+  private def verifyLogRecoverMetricsRemoved(spyLogManager: LogManager): Unit = {
+    val spyLogManagerClassName = spyLogManager.getClass().getSimpleName
+    // get all `remainingLogsToRecover` metrics
+    def logMetrics: mutable.Set[MetricName] = KafkaYammerMetrics.defaultRegistry.allMetrics.keySet.asScala
+      .filter { metric => metric.getType == s"$spyLogManagerClassName" && metric.getName == "remainingLogsToRecover" }
+
+    assertTrue(logMetrics.isEmpty)
+
+    // get all `remainingSegmentsToRecover` metrics
+    val logSegmentMetrics: mutable.Set[MetricName] = KafkaYammerMetrics.defaultRegistry.allMetrics.keySet.asScala
+      .filter { metric => metric.getType == s"$spyLogManagerClassName" && metric.getName == "remainingSegmentsToRecover" }
+
+    assertTrue(logSegmentMetrics.isEmpty)
+  }
+
+  @Test
+  def testLogRecoveryMetrics(): Unit = {
+    logManager.shutdown()
+    val logDir1 = TestUtils.tempDir()
+    val logDir2 = TestUtils.tempDir()
+    val logDirs = Seq(logDir1, logDir2)
+    val recoveryThreadsPerDataDir = 2
+    // create logManager with expected recovery thread number
+    logManager = createLogManager(logDirs, recoveryThreadsPerDataDir = recoveryThreadsPerDataDir)
+    val spyLogManager = spy(logManager)
+
+    assertEquals(2, spyLogManager.liveLogDirs.size)
+
+    val mockTime = new MockTime()
+    val mockMap = mock(classOf[ConcurrentHashMap[String, Int]])
+    val mockBrokerTopicStats = mock(classOf[BrokerTopicStats])
+    val expectedSegmentsPerLog = 2
+
+    // create log segments for log recovery in each log dir
+    appendRecordsToLog(mockTime, logDir1, 0, mockBrokerTopicStats, expectedSegmentsPerLog)
+    appendRecordsToLog(mockTime, logDir2, 1, mockBrokerTopicStats, expectedSegmentsPerLog)
+
+    // intercept loadLog method to pass expected parameter to do log recovery
+    doAnswer { invocation =>
+      val dir: File = invocation.getArgument(0)
+      val topicConfigOverrides: mutable.Map[String, LogConfig] = invocation.getArgument(5)
+
+      val topicPartition = UnifiedLog.parseTopicPartitionName(dir)
+      val config = topicConfigOverrides.getOrElse(topicPartition.topic, logConfig)
+
+      UnifiedLog(
+        dir = dir,
+        config = config,
+        logStartOffset = 0,
+        recoveryPoint = 0,
+        maxTransactionTimeoutMs = 5 * 60 * 1000,
+        maxProducerIdExpirationMs = 5 * 60 * 1000,
+        producerIdExpirationCheckIntervalMs = LogManager.ProducerIdExpirationCheckIntervalMs,
+        scheduler = mockTime.scheduler,
+        time = mockTime,
+        brokerTopicStats = mockBrokerTopicStats,
+        logDirFailureChannel = mock(classOf[LogDirFailureChannel]),
+        // not clean shutdown
+        lastShutdownClean = false,
+        topicId = None,
+        keepPartitionMetadataFile = false,
+        // pass mock map for verification later
+        numRemainingSegments = mockMap)
+
+    }.when(spyLogManager).loadLog(any[File], any[Boolean], any[Map[TopicPartition, Long]], any[Map[TopicPartition, Long]],
+      any[LogConfig], any[Map[String, LogConfig]], any[ConcurrentMap[String, Int]])
+
+    // do nothing for removeLogRecoveryMetrics for metrics verification
+    doNothing().when(spyLogManager).removeLogRecoveryMetrics()
+
+    // start the logManager to do log recovery
+    spyLogManager.startup(Set.empty)
+
+    // make sure log recovery metrics are added and removed
+    verify(spyLogManager, times(1)).addLogRecoveryMetrics(any[ConcurrentMap[String, Int]], any[ConcurrentMap[String, Int]])
+    verify(spyLogManager, times(1)).removeLogRecoveryMetrics()
+
+    // expected 1 log in each log dir since we created 2 partitions with 2 log dirs
+    val expectedRemainingLogsParams = Map[String, Int](logDir1.getAbsolutePath -> 1, logDir2.getAbsolutePath -> 1)
+    verifyRemainingLogsToRecoverMetric(spyLogManager, expectedRemainingLogsParams)
+
+    val expectedRemainingSegmentsParams = Map[String, Int](
+      logDir1.getAbsolutePath -> expectedSegmentsPerLog, logDir2.getAbsolutePath -> expectedSegmentsPerLog)
+    verifyRemainingSegmentsToRecoverMetric(spyLogManager, logDirs, recoveryThreadsPerDataDir, mockMap, expectedRemainingSegmentsParams)
+  }
+
+  @Test
+  def testLogRecoveryMetricsShouldBeRemovedAfterLogRecovered(): Unit = {
+    logManager.shutdown()
+    val logDir1 = TestUtils.tempDir()
+    val logDir2 = TestUtils.tempDir()
+    val logDirs = Seq(logDir1, logDir2)
+    val recoveryThreadsPerDataDir = 2
+    // create logManager with expected recovery thread number
+    logManager = createLogManager(logDirs, recoveryThreadsPerDataDir = recoveryThreadsPerDataDir)
+    val spyLogManager = spy(logManager)
+
+    assertEquals(2, spyLogManager.liveLogDirs.size)
+
+    // start the logManager to do log recovery
+    spyLogManager.startup(Set.empty)
+
+    // make sure log recovery metrics are added and removed once
+    verify(spyLogManager, times(1)).addLogRecoveryMetrics(any[ConcurrentMap[String, Int]], any[ConcurrentMap[String, Int]])
+    verify(spyLogManager, times(1)).removeLogRecoveryMetrics()
+
+    verifyLogRecoverMetricsRemoved(spyLogManager)
+  }
+
   @Test
   def testMetricsExistWhenLogIsRecreatedBeforeDeletion(): Unit = {
     val topicName = "metric-test"
diff --git a/core/src/test/scala/unit/kafka/log/LogTestUtils.scala b/core/src/test/scala/unit/kafka/log/LogTestUtils.scala
index f6b58d78ce..50af76f556 100644
--- a/core/src/test/scala/unit/kafka/log/LogTestUtils.scala
+++ b/core/src/test/scala/unit/kafka/log/LogTestUtils.scala
@@ -28,6 +28,7 @@ import org.apache.kafka.common.utils.{Time, Utils}
 import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse}
 
 import java.nio.file.Files
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
 import scala.collection.Iterable
 import scala.jdk.CollectionConverters._
 
@@ -83,7 +84,8 @@ object LogTestUtils {
                 producerIdExpirationCheckIntervalMs: Int = LogManager.ProducerIdExpirationCheckIntervalMs,
                 lastShutdownClean: Boolean = true,
                 topicId: Option[Uuid] = None,
-                keepPartitionMetadataFile: Boolean = true): UnifiedLog = {
+                keepPartitionMetadataFile: Boolean = true,
+                numRemainingSegments: ConcurrentMap[String, Int] = new ConcurrentHashMap[String, Int]): UnifiedLog = {
     UnifiedLog(
       dir = dir,
       config = config,
@@ -98,7 +100,8 @@ object LogTestUtils {
       logDirFailureChannel = new LogDirFailureChannel(10),
       lastShutdownClean = lastShutdownClean,
       topicId = topicId,
-      keepPartitionMetadataFile = keepPartitionMetadataFile
+      keepPartitionMetadataFile = keepPartitionMetadataFile,
+      numRemainingSegments = numRemainingSegments
     )
   }
 
diff --git a/core/src/test/scala/unit/kafka/utils/TestUtils.scala b/core/src/test/scala/unit/kafka/utils/TestUtils.scala
index 5a8d43795a..c49a7bdde0 100755
--- a/core/src/test/scala/unit/kafka/utils/TestUtils.scala
+++ b/core/src/test/scala/unit/kafka/utils/TestUtils.scala
@@ -1281,13 +1281,14 @@ object TestUtils extends Logging {
                        configRepository: ConfigRepository = new MockConfigRepository,
                        cleanerConfig: CleanerConfig = CleanerConfig(enableCleaner = false),
                        time: MockTime = new MockTime(),
-                       interBrokerProtocolVersion: MetadataVersion = MetadataVersion.latest): LogManager = {
+                       interBrokerProtocolVersion: MetadataVersion = MetadataVersion.latest,
+                       recoveryThreadsPerDataDir: Int = 4): LogManager = {
     new LogManager(logDirs = logDirs.map(_.getAbsoluteFile),
                    initialOfflineDirs = Array.empty[File],
                    configRepository = configRepository,
                    initialDefaultConfig = defaultConfig,
                    cleanerConfig = cleanerConfig,
-                   recoveryThreadsPerDataDir = 4,
+                   recoveryThreadsPerDataDir = recoveryThreadsPerDataDir,
                    flushCheckMs = 1000L,
                    flushRecoveryOffsetCheckpointMs = 10000L,
                    flushStartOffsetCheckpointMs = 10000L,
diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/server/CheckpointBench.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/server/CheckpointBench.java
index 3bf65afc22..99fb814327 100644
--- a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/server/CheckpointBench.java
+++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/server/CheckpointBench.java
@@ -108,7 +108,7 @@ public class CheckpointBench {
         this.logManager = TestUtils.createLogManager(JavaConverters.asScalaBuffer(files),
                 LogConfig.apply(), new MockConfigRepository(), CleanerConfig.apply(1, 4 * 1024 * 1024L, 0.9d,
                         1024 * 1024, 32 * 1024 * 1024,
-                        Double.MAX_VALUE, 15 * 1000, true, "MD5"), time, MetadataVersion.latest());
+                        Double.MAX_VALUE, 15 * 1000, true, "MD5"), time, MetadataVersion.latest(), 4);
         scheduler.startup();
         final BrokerTopicStats brokerTopicStats = new BrokerTopicStats();
         final MetadataCache metadataCache =
