This is an automated email from the ASF dual-hosted git repository.

divijv pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 88e784f7c6f KAFKA-15084: Remove lock contention from RemoteIndexCache (#13850)
88e784f7c6f is described below

commit 88e784f7c6fa7df91965076d8cb0e0e691719cdc
Author: Divij Vaidya <[email protected]>
AuthorDate: Wed Jun 21 18:22:49 2023 +0200

    KAFKA-15084: Remove lock contention from RemoteIndexCache (#13850)
    
    Use thread safe Caffeine to cache indexes fetched from RemoteTier locally.
    This PR removes lock contention that led to higher fetch latencies, as IO
    threads spent time unnecessarily waiting on the global cache lock while a
    single thread fetched the index from the remote tier. See PR #13850 for
    details and rejected alternatives.
    
    Reviewers: Luke Chen <[email protected]>, Satish Duggana <[email protected]>
---
 LICENSE-binary                                     |   1 +
 build.gradle                                       |   1 +
 .../scala/kafka/log/remote/RemoteIndexCache.scala  | 306 ++++++++++++-------
 .../kafka/log/remote/RemoteIndexCacheTest.scala    | 333 +++++++++++++++++----
 gradle/dependencies.gradle                         |   2 +
 .../kafka/server/util/ShutdownableThread.java      |   3 +
 6 files changed, 486 insertions(+), 160 deletions(-)
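
For context before the diff: the property this change relies on is that Caffeine's get(key, mappingFunction) runs the loader at most once per key and does not block lookups of other keys while a load is in flight, which is what removes the need for a global cache lock. A minimal, self-contained Scala sketch of that behaviour (illustrative only, not part of this patch; expensiveRemoteFetch is a hypothetical stand-in):

    import com.github.benmanes.caffeine.cache.{Cache, Caffeine, RemovalCause}

    object CaffeineSketch extends App {
      val cache: Cache[String, String] = Caffeine.newBuilder()
        .maximumSize(2)
        // invoked on both manual invalidation and policy-based eviction
        .removalListener((key: String, _: String, cause: RemovalCause) =>
          println(s"removed $key because $cause"))
        .build[String, String]()

      // hypothetical stand-in for a slow fetch from remote storage
      def expensiveRemoteFetch(id: String): String = s"index-for-$id"

      // only one thread runs the loader for a given key; lookups of keys
      // already present in the cache proceed without waiting on this load
      println(cache.get("segment-1", id => expensiveRemoteFetch(id)))
    }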

diff --git a/LICENSE-binary b/LICENSE-binary
index 916d4192dca..0b6845ae328 100644
--- a/LICENSE-binary
+++ b/LICENSE-binary
@@ -206,6 +206,7 @@ This project bundles some components that are also licensed under the Apache
 License Version 2.0:
 
 audience-annotations-0.13.0
+caffeine-2.9.3
 commons-beanutils-1.9.4
 commons-cli-1.4
 commons-collections-3.2.2
diff --git a/build.gradle b/build.gradle
index 0ffb019cf81..62c1e701991 100644
--- a/build.gradle
+++ b/build.gradle
@@ -873,6 +873,7 @@ project(':core') {
 
 
     implementation libs.argparse4j
+    implementation libs.caffeine
     implementation libs.commonsValidator
     implementation libs.jacksonDatabind
     implementation libs.jacksonModuleScala
diff --git a/core/src/main/scala/kafka/log/remote/RemoteIndexCache.scala b/core/src/main/scala/kafka/log/remote/RemoteIndexCache.scala
index 38f2c7a9192..5d2601faf92 100644
--- a/core/src/main/scala/kafka/log/remote/RemoteIndexCache.scala
+++ b/core/src/main/scala/kafka/log/remote/RemoteIndexCache.scala
@@ -16,9 +16,11 @@
  */
 package kafka.log.remote
 
+import com.github.benmanes.caffeine.cache.{Cache, Caffeine, RemovalCause}
 import kafka.log.UnifiedLog
-import kafka.log.remote.RemoteIndexCache.DirName
-import kafka.utils.{CoreUtils, Logging}
+import kafka.log.remote.RemoteIndexCache.{DirName, remoteLogIndexCacheCleanerThread}
+import kafka.utils.CoreUtils.{inReadLock, inWriteLock}
+import kafka.utils.{CoreUtils, Logging, threadsafe}
 import org.apache.kafka.common.Uuid
 import org.apache.kafka.common.errors.CorruptRecordException
 import org.apache.kafka.common.utils.Utils
@@ -26,104 +28,172 @@ import org.apache.kafka.server.log.remote.storage.RemoteStorageManager.IndexType
 import org.apache.kafka.server.log.remote.storage.{RemoteLogSegmentMetadata, RemoteStorageManager}
 import org.apache.kafka.storage.internals.log.{LogFileUtils, OffsetIndex, OffsetPosition, TimeIndex, TransactionIndex}
 import org.apache.kafka.server.util.ShutdownableThread
-import java.io.{Closeable, File, InputStream}
-import java.nio.file.{Files, Path}
-import java.util
+
+import java.io.{File, InputStream}
+import java.nio.file.{FileAlreadyExistsException, Files, Path}
 import java.util.concurrent.LinkedBlockingQueue
+import java.util.concurrent.atomic.AtomicBoolean
 import java.util.concurrent.locks.ReentrantReadWriteLock
 
 object RemoteIndexCache {
   val DirName = "remote-log-index-cache"
   val TmpFileSuffix = ".tmp"
+  val remoteLogIndexCacheCleanerThread = "remote-log-index-cleaner"
 }
 
-class Entry(val offsetIndex: OffsetIndex, val timeIndex: TimeIndex, val txnIndex: TransactionIndex) {
-  private var markedForCleanup: Boolean = false
+@threadsafe
+class Entry(val offsetIndex: OffsetIndex, val timeIndex: TimeIndex, val txnIndex: TransactionIndex) extends AutoCloseable {
+  // visible for testing
+  private[remote] var markedForCleanup = false
+  // visible for testing
+  private[remote] var cleanStarted = false
+  // This lock is used to synchronize cleanup methods and read methods. This ensures that cleanup (which changes the
+  // underlying files of the index) isn't performed while a read is in-progress for the entry. This is required in
+  // addition to using the thread safe cache because, while the thread safety of the cache ensures that we can read
+  // entries concurrently, it does not ensure that we won't mutate underlying files belonging to an entry.
   private val lock: ReentrantReadWriteLock = new ReentrantReadWriteLock()
 
   def lookupOffset(targetOffset: Long): OffsetPosition = {
-    CoreUtils.inLock(lock.readLock()) {
+    inReadLock(lock) {
       if (markedForCleanup) throw new IllegalStateException("This entry is marked for cleanup")
-      else offsetIndex.lookup(targetOffset)
+      offsetIndex.lookup(targetOffset)
     }
   }
 
   def lookupTimestamp(timestamp: Long, startingOffset: Long): OffsetPosition = {
-    CoreUtils.inLock(lock.readLock()) {
+    inReadLock(lock) {
       if (markedForCleanup) throw new IllegalStateException("This entry is marked for cleanup")
-
       val timestampOffset = timeIndex.lookup(timestamp)
       offsetIndex.lookup(math.max(startingOffset, timestampOffset.offset))
     }
   }
 
-  def markForCleanup(): Unit = {
-    CoreUtils.inLock(lock.writeLock()) {
+  private[remote] def markForCleanup(): Unit = {
+    inWriteLock(lock) {
       if (!markedForCleanup) {
         markedForCleanup = true
         Array(offsetIndex, timeIndex).foreach(index =>
           index.renameTo(new File(Utils.replaceSuffix(index.file.getPath, "", LogFileUtils.DELETED_FILE_SUFFIX))))
+        // txn index needs to be renamed separately since it's not of type AbstractIndex
         txnIndex.renameTo(new File(Utils.replaceSuffix(txnIndex.file.getPath, "",
           LogFileUtils.DELETED_FILE_SUFFIX)))
       }
     }
   }
 
-  def cleanup(): Unit = {
-    markForCleanup()
-    CoreUtils.tryAll(Seq(() => offsetIndex.deleteIfExists(), () => timeIndex.deleteIfExists(), () => txnIndex.deleteIfExists()))
+  /**
+   * Deletes the index files from the disk. Invoking #close is not required prior to this function.
+   */
+  private[remote] def cleanup(): Unit = {
+    inWriteLock(lock) {
+      markForCleanup()
+      // no-op if clean is done already
+      if (!cleanStarted) {
+        cleanStarted = true
+        CoreUtils.tryAll(Seq(() => offsetIndex.deleteIfExists(), () => timeIndex.deleteIfExists(), () => txnIndex.deleteIfExists()))
+      }
+    }
   }
 
+  /**
+   * Calls the underlying close method for each index which may lead to releasing resources such as mmap.
+   * This function does not delete the index files.
+   */
+  @Override
   def close(): Unit = {
-    Array(offsetIndex, timeIndex).foreach(index => try {
-      index.close()
-    } catch {
-      case _: Exception => // ignore error.
-    })
-    Utils.closeQuietly(txnIndex, "Closing the transaction index.")
+    inWriteLock(lock) {
+      // close is no-op if entry is already marked for cleanup. Mmap resources are released during cleanup.
+      if (!markedForCleanup) {
+        Utils.closeQuietly(offsetIndex, "Closing the offset index.")
+        Utils.closeQuietly(timeIndex, "Closing the time index.")
+        Utils.closeQuietly(txnIndex, "Closing the transaction index.")
+      }
+    }
   }
 }
 
 /**
- * This is a LRU cache of remote index files stored in `$logdir/remote-log-index-cache`. This is helpful to avoid
- * re-fetching the index files like offset, time indexes from the remote storage for every fetch call.
+ * This is an LFU (Least Frequently Used) cache of remote index files stored in `$logdir/remote-log-index-cache`.
+ * This is helpful to avoid re-fetching the index files like offset, time indexes from the remote storage for every
+ * fetch call. The cache is re-initialized from the index files on disk on startup, if the index files are available.
+ *
+ * The cache contains a garbage collection thread which will delete the files for entries that have been removed from
+ * the cache.
+ *
+ * Note that closing this cache does not delete the index files on disk.
+ * Note that the cache eviction policy is based on the default implementation of Caffeine i.e.
+ * <a href="https://github.com/ben-manes/caffeine/wiki/Efficiency">Window TinyLfu</a>. TinyLfu relies on a frequency
+ * sketch to probabilistically estimate the historic usage of an entry.
  *
  * @param maxSize              maximum number of segment index entries to be cached.
  * @param remoteStorageManager RemoteStorageManager instance, to be used in fetching indexes.
  * @param logDir               log directory
  */
+@threadsafe
 class RemoteIndexCache(maxSize: Int = 1024, remoteStorageManager: RemoteStorageManager, logDir: String)
-  extends Logging with Closeable {
-
-  val cacheDir = new File(logDir, DirName)
-  @volatile var closed = false
-
-  val expiredIndexes = new LinkedBlockingQueue[Entry]()
-  val lock = new Object()
-
-  val entries: util.Map[Uuid, Entry] = new java.util.LinkedHashMap[Uuid, Entry](maxSize / 2,
-    0.75f, true) {
-    override def removeEldestEntry(eldest: util.Map.Entry[Uuid, Entry]): Boolean = {
-      if (this.size() > maxSize) {
-        val entry = eldest.getValue
-        // Mark the entries for cleanup, background thread will clean them later.
-        entry.markForCleanup()
-        expiredIndexes.add(entry)
-        true
-      } else {
-        false
+  extends Logging with AutoCloseable {
+  /**
+   * Directory where the index files will be stored on disk.
+   */
+  private val cacheDir = new File(logDir, DirName)
+  /**
+   * Represents if the cache is closed or not. Closing the cache is an irreversible operation.
+   */
+  private val isRemoteIndexCacheClosed: AtomicBoolean = new AtomicBoolean(false)
+  /**
+   * Unbounded queue containing the removed entries from the cache which are waiting to be garbage collected.
+   *
+   * Visible for testing
+   */
+  private[remote] val expiredIndexes = new LinkedBlockingQueue[Entry]()
+  /**
+   * Lock used to synchronize close with other read operations. This ensures that when we close, we don't have any other
+   * concurrent reads in-progress.
+   */
+  private val lock: ReentrantReadWriteLock = new ReentrantReadWriteLock()
+  /**
+   * Actual cache implementation that this file wraps around.
+   *
+   * The requirements for this internal cache are as follows:
+   * 1. Multiple threads should be able to read concurrently.
+   * 2. Fetch for missing keys should not block read for available keys.
+   * 3. Only one thread should fetch for a specific key.
+   * 4. Should support LRU-like policy.
+   *
+   * We use [[Caffeine]] cache instead of implementing a thread safe LRU cache on our own.
+   *
+   * Visible for testing.
+   */
+  private[remote] var internalCache: Cache[Uuid, Entry] = Caffeine.newBuilder()
+    .maximumSize(maxSize)
+    // removalListener is invoked when either the entry is invalidated (meaning manual removal by the caller) or
+    // evicted (meaning removal due to the policy)
+    .removalListener((_: Uuid, entry: Entry, _: RemovalCause) => {
+      // Mark the entries for cleanup and add them to the queue to be garbage collected later by the background thread.
+      entry.markForCleanup()
+      if (!expiredIndexes.offer(entry)) {
+        error(s"Error while inserting entry $entry into the cleaner queue")
       }
-    }
-  }
+    })
+    .build[Uuid, Entry]()
 
   private def init(): Unit = {
-    if (cacheDir.mkdir())
-      info(s"Created $cacheDir successfully")
+    try {
+      Files.createDirectory(cacheDir.toPath)
+      info(s"Created new file $cacheDir for RemoteIndexCache")
+    } catch {
+      case _: FileAlreadyExistsException =>
+        info(s"RemoteIndexCache directory $cacheDir already exists. Re-using 
the same directory.")
+      case e: Exception =>
+        error(s"Unable to create directory $cacheDir for RemoteIndexCache.", e)
+        throw e
+    }
 
     // Delete any .deleted files remained from the earlier run of the broker.
     Files.list(cacheDir.toPath).forEach((path: Path) => {
       if (path.endsWith(LogFileUtils.DELETED_FILE_SUFFIX)) {
-        Files.deleteIfExists(path)
+        if (Files.deleteIfExists(path))
+          debug(s"Deleted file $path on cache initialization")
       }
     })
 
@@ -136,12 +206,16 @@ class RemoteIndexCache(maxSize: Int = 1024, remoteStorageManager: RemoteStorageM
       val offset = name.substring(0, firstIndex).toInt
       val uuid = Uuid.fromString(name.substring(firstIndex + 1, name.lastIndexOf('_')))
 
-      if(!entries.containsKey(uuid)) {
+      // It is safe to update the internalCache non-atomically here since this function is always called by a single
+      // thread only.
+      if (!internalCache.asMap().containsKey(uuid)) {
         val offsetIndexFile = new File(cacheDir, name + UnifiedLog.IndexFileSuffix)
         val timestampIndexFile = new File(cacheDir, name + UnifiedLog.TimeIndexFileSuffix)
         val txnIndexFile = new File(cacheDir, name + UnifiedLog.TxnIndexFileSuffix)
 
-        if (offsetIndexFile.exists() && timestampIndexFile.exists() && txnIndexFile.exists()) {
+        if (Files.exists(offsetIndexFile.toPath) &&
+            Files.exists(timestampIndexFile.toPath) &&
+            Files.exists(txnIndexFile.toPath)) {
 
           val offsetIndex = new OffsetIndex(offsetIndexFile, offset, Int.MaxValue, false)
           offsetIndex.sanityCheck()
@@ -152,8 +226,7 @@ class RemoteIndexCache(maxSize: Int = 1024, remoteStorageManager: RemoteStorageM
           val txnIndex = new TransactionIndex(offset, txnIndexFile)
           txnIndex.sanityCheck()
 
-          val entry = new Entry(offsetIndex, timeIndex, txnIndex)
-          entries.put(uuid, entry)
+          internalCache.put(uuid, new Entry(offsetIndex, timeIndex, txnIndex))
         } else {
           // Delete all of them if any one of those indexes is not available for a specific segment id
           Files.deleteIfExists(offsetIndexFile.toPath)
@@ -167,64 +240,76 @@ class RemoteIndexCache(maxSize: Int = 1024, remoteStorageManager: RemoteStorageM
   init()
 
   // Start cleaner thread that will clean the expired entries
-  val cleanerThread: ShutdownableThread = new ShutdownableThread("remote-log-index-cleaner") {
+  private[remote] var cleanerThread: ShutdownableThread = new ShutdownableThread(remoteLogIndexCacheCleanerThread) {
     setDaemon(true)
 
     override def doWork(): Unit = {
-      while (!closed) {
-        try {
+      try {
+        while (!isRemoteIndexCacheClosed.get()) {
           val entry = expiredIndexes.take()
-          info(s"Cleaning up index entry $entry")
+          debug(s"Cleaning up index entry $entry")
           entry.cleanup()
-        } catch {
-          case ex: InterruptedException => info("Cleaner thread was interrupted", ex)
-          case ex: Exception => error("Error occurred while fetching/cleaning up expired entry", ex)
         }
+      } catch {
+        case ex: InterruptedException =>
+          // cleaner thread should only be interrupted when cache is being closed, else it's an error
+          if (!isRemoteIndexCacheClosed.get()) {
+            error("Cleaner thread received interruption but remote index cache is not closed", ex)
+            throw ex
+          } else {
+            debug("Cleaner thread was interrupted on cache shutdown")
+          }
+        case ex: Exception => error("Error occurred while fetching/cleaning up expired entry", ex)
       }
     }
   }
+
   cleanerThread.start()
 
   def getIndexEntry(remoteLogSegmentMetadata: RemoteLogSegmentMetadata): Entry = {
-    if(closed) throw new IllegalStateException("Instance is already closed.")
-
-    def loadIndexFile[T](fileName: String,
-                         suffix: String,
-                         fetchRemoteIndex: RemoteLogSegmentMetadata => InputStream,
-                         readIndex: File => T): T = {
-      val indexFile = new File(cacheDir, fileName + suffix)
-
-      def fetchAndCreateIndex(): T = {
-        val tmpIndexFile = new File(cacheDir, fileName + suffix + RemoteIndexCache.TmpFileSuffix)
-
-        val inputStream = fetchRemoteIndex(remoteLogSegmentMetadata)
-        try {
-          Files.copy(inputStream, tmpIndexFile.toPath)
-        } finally {
-          if (inputStream != null) {
-            inputStream.close()
-          }
-        }
+    if (isRemoteIndexCacheClosed.get()) {
+      throw new IllegalStateException(s"Unable to fetch index for " +
+        s"segment id=${remoteLogSegmentMetadata.remoteLogSegmentId().id()}. 
Index instance is already closed.")
+    }
 
-        Utils.atomicMoveWithFallback(tmpIndexFile.toPath, indexFile.toPath, false)
-        readIndex(indexFile)
-      }
+    inReadLock(lock) {
+      val cacheKey = remoteLogSegmentMetadata.remoteLogSegmentId().id()
+      internalCache.get(cacheKey, (uuid: Uuid) => {
+        def loadIndexFile[T](fileName: String,
+                             suffix: String,
+                             fetchRemoteIndex: RemoteLogSegmentMetadata => InputStream,
+                             readIndex: File => T): T = {
+          val indexFile = new File(cacheDir, fileName + suffix)
+
+          def fetchAndCreateIndex(): T = {
+            val tmpIndexFile = new File(cacheDir, fileName + suffix + RemoteIndexCache.TmpFileSuffix)
+
+            val inputStream = fetchRemoteIndex(remoteLogSegmentMetadata)
+            try {
+              Files.copy(inputStream, tmpIndexFile.toPath)
+            } finally {
+              if (inputStream != null) {
+                inputStream.close()
+              }
+            }
+
+            Utils.atomicMoveWithFallback(tmpIndexFile.toPath, indexFile.toPath, false)
+            readIndex(indexFile)
+          }
 
-      if (indexFile.exists()) {
-        try {
-          readIndex(indexFile)
-        } catch {
-          case ex: CorruptRecordException =>
-            info("Error occurred while loading the stored index", ex)
+          if (Files.exists(indexFile.toPath)) {
+            try {
+              readIndex(indexFile)
+            } catch {
+              case ex: CorruptRecordException =>
+                info(s"Error occurred while loading the stored index at 
${indexFile.toPath}", ex)
+                fetchAndCreateIndex()
+            }
+          } else {
             fetchAndCreateIndex()
+          }
         }
-      } else {
-        fetchAndCreateIndex()
-      }
-    }
 
-    lock synchronized {
-      entries.computeIfAbsent(remoteLogSegmentMetadata.remoteLogSegmentId().id(), (uuid: Uuid) => {
         val startOffset = remoteLogSegmentMetadata.startOffset()
         // uuid.toString uses URL encoding which is safe for filenames and URLs.
         val fileName = startOffset.toString + "_" + uuid.toString + "_"
@@ -259,20 +344,39 @@ class RemoteIndexCache(maxSize: Int = 1024, remoteStorageManager: RemoteStorageM
   }
 
   def lookupOffset(remoteLogSegmentMetadata: RemoteLogSegmentMetadata, offset: Long): Int = {
-    getIndexEntry(remoteLogSegmentMetadata).lookupOffset(offset).position
+    inReadLock(lock) {
+      getIndexEntry(remoteLogSegmentMetadata).lookupOffset(offset).position
+    }
   }
 
   def lookupTimestamp(remoteLogSegmentMetadata: RemoteLogSegmentMetadata, timestamp: Long, startingOffset: Long): Int = {
-    getIndexEntry(remoteLogSegmentMetadata).lookupTimestamp(timestamp, startingOffset).position
+    inReadLock(lock) {
+      getIndexEntry(remoteLogSegmentMetadata).lookupTimestamp(timestamp, startingOffset).position
+    }
   }
 
+  /**
+   * Close should synchronously cleanup the resources used by this cache.
+   * This index is closed when [[RemoteLogManager]] is closed.
+   */
   def close(): Unit = {
-    closed = true
-    cleanerThread.shutdown()
-    // Close all the opened indexes.
-    lock synchronized {
-      entries.values().stream().forEach(entry => entry.close())
+    // make close idempotent and ensure no more reads are allowed from here on. The in-progress reads will continue to
+    // completion (release the read lock) and then close will begin executing. Cleaner thread will immediately stop work.
+    if (!isRemoteIndexCacheClosed.getAndSet(true)) {
+      inWriteLock(lock) {
+        info(s"Close initiated for RemoteIndexCache. Cache stats=${internalCache.stats}. " +
+          s"Cache entries pending delete=${expiredIndexes.size()}")
+        // Initiate shutdown for cleaning thread
+        val shutdownRequired = cleanerThread.initiateShutdown()
+        // Close all the opened indexes to force unload mmap memory. This does not delete the index files from disk.
+        internalCache.asMap().forEach((_, entry) => entry.close())
+        // wait for cleaner thread to shutdown
+        if (shutdownRequired) cleanerThread.awaitShutdown()
+        // Note that internal cache does not require explicit cleaning / closing. We don't want to invalidate or cleanup
+        // the cache as both would lead to triggering of removal listener.
+        internalCache = null
+        info(s"Close completed for RemoteIndexCache")
+      }
     }
   }
-
 }
diff --git a/core/src/test/scala/unit/kafka/log/remote/RemoteIndexCacheTest.scala b/core/src/test/scala/unit/kafka/log/remote/RemoteIndexCacheTest.scala
index 98e3ab8c8e8..5426984acac 100644
--- a/core/src/test/scala/unit/kafka/log/remote/RemoteIndexCacheTest.scala
+++ b/core/src/test/scala/unit/kafka/log/remote/RemoteIndexCacheTest.scala
@@ -21,56 +21,61 @@ import org.apache.kafka.common.{TopicIdPartition, TopicPartition, Uuid}
 import org.apache.kafka.server.log.remote.storage.RemoteStorageManager.IndexType
 import org.apache.kafka.server.log.remote.storage.{RemoteLogSegmentId, RemoteLogSegmentMetadata, RemoteStorageManager}
 import org.apache.kafka.server.util.MockTime
-import org.apache.kafka.storage.internals.log.{OffsetIndex, OffsetPosition, TimeIndex}
-import org.apache.kafka.test.TestUtils
-import org.junit.jupiter.api.Assertions._
+import org.apache.kafka.storage.internals.log.{LogFileUtils, OffsetIndex, OffsetPosition, TimeIndex, TransactionIndex}
+import kafka.utils.TestUtils
+import org.apache.kafka.common.utils.Utils
+import org.junit.jupiter.api.Assertions.{assertTrue, _}
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
 import org.mockito.ArgumentMatchers
 import org.mockito.ArgumentMatchers.any
 import org.mockito.Mockito._
+import org.slf4j.{Logger, LoggerFactory}
 
 import java.io.{File, FileInputStream}
 import java.nio.file.Files
 import java.util.Collections
+import java.util.concurrent.{CountDownLatch, Executors, TimeUnit}
 import scala.collection.mutable
 
 class RemoteIndexCacheTest {
-
-  val time = new MockTime()
-  val partition = new TopicPartition("foo", 0)
-  val idPartition = new TopicIdPartition(Uuid.randomUuid(), partition)
-  val logDir: File = TestUtils.tempDirectory("kafka-logs")
-  val tpDir: File = new File(logDir, partition.toString)
-  val brokerId = 1
-  val baseOffset = 45L
-  val lastOffset = 75L
-  val segmentSize = 1024
-
-  val rsm: RemoteStorageManager = mock(classOf[RemoteStorageManager])
-  val cache: RemoteIndexCache =  new RemoteIndexCache(remoteStorageManager = rsm, logDir = logDir.toString)
-  val remoteLogSegmentId = new RemoteLogSegmentId(idPartition, Uuid.randomUuid())
-  val rlsMetadata: RemoteLogSegmentMetadata = new RemoteLogSegmentMetadata(remoteLogSegmentId, baseOffset, lastOffset,
-    time.milliseconds(), brokerId, time.milliseconds(), segmentSize, Collections.singletonMap(0, 0L))
+  private val logger: Logger = LoggerFactory.getLogger(classOf[RemoteIndexCacheTest])
+  private val time = new MockTime()
+  private val partition = new TopicPartition("foo", 0)
+  private val brokerId = 1
+  private val baseOffset = 45L
+  private val lastOffset = 75L
+  private val segmentSize = 1024
+  private val rsm: RemoteStorageManager = mock(classOf[RemoteStorageManager])
+  private var cache: RemoteIndexCache = _
+  private var rlsMetadata: RemoteLogSegmentMetadata = _
+  private var logDir: File = _
+  private var tpDir: File = _
 
   @BeforeEach
   def setup(): Unit = {
+    val idPartition = new TopicIdPartition(Uuid.randomUuid(), partition)
+    logDir = TestUtils.tempDir()
+    tpDir = new File(logDir, idPartition.toString)
     Files.createDirectory(tpDir.toPath)
-    val txnIdxFile = new File(tpDir, "txn-index" + UnifiedLog.TxnIndexFileSuffix)
-    txnIdxFile.createNewFile()
+
+    val remoteLogSegmentId = new RemoteLogSegmentId(idPartition, Uuid.randomUuid())
+    rlsMetadata = new RemoteLogSegmentMetadata(remoteLogSegmentId, baseOffset, lastOffset,
+      time.milliseconds(), brokerId, time.milliseconds(), segmentSize, Collections.singletonMap(0, 0L))
+
+    cache = new RemoteIndexCache(remoteStorageManager = rsm, logDir = logDir.toString)
+
     when(rsm.fetchIndex(any(classOf[RemoteLogSegmentMetadata]), any(classOf[IndexType])))
       .thenAnswer(ans => {
         val metadata = ans.getArgument[RemoteLogSegmentMetadata](0)
         val indexType = ans.getArgument[IndexType](1)
-        val maxEntries = (metadata.endOffset() - metadata.startOffset()).asInstanceOf[Int]
-        val offsetIdx = new OffsetIndex(new File(tpDir, String.valueOf(metadata.startOffset()) + UnifiedLog.IndexFileSuffix),
-          metadata.startOffset(), maxEntries * 8)
-        val timeIdx = new TimeIndex(new File(tpDir, String.valueOf(metadata.startOffset()) + UnifiedLog.TimeIndexFileSuffix),
-          metadata.startOffset(), maxEntries * 12)
+        val offsetIdx = createOffsetIndexForSegmentMetadata(metadata)
+        val timeIdx = createTimeIndexForSegmentMetadata(metadata)
+        val trxIdx = createTxIndexForSegmentMetadata(metadata)
         maybeAppendIndexEntries(offsetIdx, timeIdx)
         indexType match {
           case IndexType.OFFSET => new FileInputStream(offsetIdx.file)
           case IndexType.TIMESTAMP => new FileInputStream(timeIdx.file)
-          case IndexType.TRANSACTION => new FileInputStream(txnIdxFile)
+          case IndexType.TRANSACTION => new FileInputStream(trxIdx.file)
           case IndexType.LEADER_EPOCH => // leader-epoch-cache is not accessed.
          case IndexType.PRODUCER_SNAPSHOT => // producer-snapshot is not accessed.
         }
@@ -80,8 +85,13 @@ class RemoteIndexCacheTest {
   @AfterEach
   def cleanup(): Unit = {
     reset(rsm)
-    cache.entries.forEach((_, v) => v.cleanup())
-    cache.close()
+    // the files created for the test will be deleted automatically on thread exit since we use temp dir
+    Utils.closeQuietly(cache, "RemoteIndexCache created for unit test")
+    // best effort to delete the per-test resource. Even if we don't delete, it is ok because the parent directory
+    // will be deleted at the end of test.
+    Utils.delete(logDir)
+    // Verify no lingering threads
+    TestUtils.assertNoNonDaemonThreads(RemoteIndexCache.remoteLogIndexCacheCleanerThread)
   }
 
   @Test
@@ -117,51 +127,60 @@ class RemoteIndexCacheTest {
 
   @Test
   def testCacheEntryExpiry(): Unit = {
-    val cache = new RemoteIndexCache(maxSize = 2, rsm, logDir = logDir.toString)
+    // close existing cache created in test setup before creating a new one
+    Utils.closeQuietly(cache, "RemoteIndexCache created for unit test")
+    cache = new RemoteIndexCache(maxSize = 2, rsm, logDir = logDir.toString)
     val tpId = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0))
     val metadataList = generateRemoteLogSegmentMetadata(size = 3, tpId)
 
-    assertEquals(0, cache.entries.size())
+    assertCacheSize(0)
     // getIndex for first time will call rsm#fetchIndex
     cache.getIndexEntry(metadataList.head)
-    assertEquals(1, cache.entries.size())
+    assertCacheSize(1)
     // Calling getIndex on the same entry should not call rsm#fetchIndex again, but it should retrieve from cache
     cache.getIndexEntry(metadataList.head)
-    assertEquals(1, cache.entries.size())
+    assertCacheSize(1)
     verifyFetchIndexInvocation(count = 1)
 
     // Here a new key metadataList(1) is invoked, that should call rsm#fetchIndex, making the count to 2
     cache.getIndexEntry(metadataList.head)
     cache.getIndexEntry(metadataList(1))
-    assertEquals(2, cache.entries.size())
+    assertCacheSize(2)
     verifyFetchIndexInvocation(count = 2)
 
-    // getting index for metadataList.last should call rsm#fetchIndex, but metadataList(1) is already in cache.
-    cache.getIndexEntry(metadataList.last)
-    cache.getIndexEntry(metadataList(1))
-    assertEquals(2, cache.entries.size())
-    assertTrue(cache.entries.containsKey(metadataList.last.remoteLogSegmentId().id()))
-    assertTrue(cache.entries.containsKey(metadataList(1).remoteLogSegmentId().id()))
+    // Getting index for metadataList.last should call rsm#fetchIndex.
+    // To populate this entry, one of the other 2 entries will be evicted. We don't know which one since it's based on
+    // a probabilistic formula for Window TinyLfu. See docs for RemoteIndexCache.
+    assertNotNull(cache.getIndexEntry(metadataList.last))
+    assertAtLeastOnePresent(cache, metadataList(1).remoteLogSegmentId().id(), metadataList.head.remoteLogSegmentId().id())
+    assertCacheSize(2)
     verifyFetchIndexInvocation(count = 3)
 
-    // getting index for metadataList.head should call rsm#fetchIndex as that entry was expired earlier,
-    // but metadataList(1) is already in cache.
-    cache.getIndexEntry(metadataList(1))
-    cache.getIndexEntry(metadataList.head)
-    assertEquals(2, cache.entries.size())
-    assertFalse(cache.entries.containsKey(metadataList.last.remoteLogSegmentId().id()))
+    // getting index for the last expired entry should call rsm#fetchIndex as that entry was expired earlier
+    val missingEntryOpt = {
+      metadataList.find(segmentMetadata => {
+        val segmentId = segmentMetadata.remoteLogSegmentId().id()
+        !cache.internalCache.asMap().containsKey(segmentId)
+      })
+    }
+    assertFalse(missingEntryOpt.isEmpty)
+    cache.getIndexEntry(missingEntryOpt.get)
+    assertCacheSize(2)
     verifyFetchIndexInvocation(count = 4)
   }
 
   @Test
   def testGetIndexAfterCacheClose(): Unit = {
-    val cache = new RemoteIndexCache(maxSize = 2, rsm, logDir = logDir.toString)
+    // close existing cache created in test setup before creating a new one
+    Utils.closeQuietly(cache, "RemoteIndexCache created for unit test")
+
+    cache = new RemoteIndexCache(maxSize = 2, rsm, logDir = logDir.toString)
     val tpId = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0))
     val metadataList = generateRemoteLogSegmentMetadata(size = 3, tpId)
 
-    assertEquals(0, cache.entries.size())
+    assertCacheSize(0)
     cache.getIndexEntry(metadataList.head)
-    assertEquals(1, cache.entries.size())
+    assertCacheSize(1)
     verifyFetchIndexInvocation(count = 1)
 
     cache.close()
@@ -170,35 +189,188 @@ class RemoteIndexCacheTest {
     assertThrows(classOf[IllegalStateException], () => cache.getIndexEntry(metadataList.head))
   }
 
+  @Test
+  def testCloseIsIdempotent(): Unit = {
+    // generate and add entry to cache
+    val spyEntry = generateSpyCacheEntry()
+    cache.internalCache.put(rlsMetadata.remoteLogSegmentId().id(), spyEntry)
+
+    cache.close()
+    cache.close()
+
+    // verify that entry is only closed once
+    verify(spyEntry).close()
+  }
+
+  @Test
+  def testCacheEntryIsDeletedOnInvalidation(): Unit = {
+    def getIndexFileFromDisk(suffix: String) = {
+      Files.walk(tpDir.toPath)
+        .filter(Files.isRegularFile(_))
+        .filter(path => path.getFileName.toString.endsWith(suffix))
+        .findAny()
+    }
+
+    val internalIndexKey = rlsMetadata.remoteLogSegmentId().id()
+    val cacheEntry = generateSpyCacheEntry()
+
+    // verify index files on disk
+    assertTrue(getIndexFileFromDisk(UnifiedLog.IndexFileSuffix).isPresent, s"Offset index file should be present on disk at ${tpDir.toPath}")
+    assertTrue(getIndexFileFromDisk(UnifiedLog.TxnIndexFileSuffix).isPresent, s"Txn index file should be present on disk at ${tpDir.toPath}")
+    assertTrue(getIndexFileFromDisk(UnifiedLog.TimeIndexFileSuffix).isPresent, s"Time index file should be present on disk at ${tpDir.toPath}")
+
+    // add the spied entry into the cache, it will overwrite the non-spied entry
+    cache.internalCache.put(internalIndexKey, cacheEntry)
+
+    // no expired entries yet
+    assertEquals(0, cache.expiredIndexes.size, "expiredIndex queue should be zero at start of test")
+
+    // invalidate the cache. it should asynchronously mark the entry for removal
+    cache.internalCache.invalidate(internalIndexKey)
+
+    // wait until entry is marked for deletion
+    TestUtils.waitUntilTrue(() => cacheEntry.markedForCleanup,
+      "Failed to mark cache entry for cleanup after invalidation")
+    TestUtils.waitUntilTrue(() => cacheEntry.cleanStarted,
+      "Failed to cleanup cache entry after invalidation")
+
+    // first it is marked for cleanup on invalidation; markForCleanup is called a second time from cleanup()
+    verify(cacheEntry, times(2)).markForCleanup()
+    // after that it is cleaned up asynchronously
+    verify(cacheEntry).cleanup()
+
+    // verify that renameTo() is called exactly once for each index
+    verify(cacheEntry.timeIndex).renameTo(any(classOf[File]))
+    verify(cacheEntry.offsetIndex).renameTo(any(classOf[File]))
+    verify(cacheEntry.txnIndex).renameTo(any(classOf[File]))
+
+    // verify no index files on disk
+    assertFalse(getIndexFileFromDisk(UnifiedLog.IndexFileSuffix).isPresent,
+      s"Offset index file should not be present on disk at ${tpDir.toPath}")
+    assertFalse(getIndexFileFromDisk(UnifiedLog.TxnIndexFileSuffix).isPresent,
+      s"Txn index file should not be present on disk at ${tpDir.toPath}")
+    assertFalse(getIndexFileFromDisk(UnifiedLog.TimeIndexFileSuffix).isPresent,
+      s"Time index file should not be present on disk at ${tpDir.toPath}")
+    assertFalse(getIndexFileFromDisk(LogFileUtils.DELETED_FILE_SUFFIX).isPresent,
+      s"Index file marked for deletion should not be present on disk at ${tpDir.toPath}")
+  }
+
+  @Test
+  def testClose(): Unit = {
+    val spyEntry = generateSpyCacheEntry()
+    cache.internalCache.put(rlsMetadata.remoteLogSegmentId().id(), spyEntry)
+
+    // close the cache
+    cache.close()
+
+    // closing the cache should close the entry
+    verify(spyEntry).close()
+
+    // close for all index entries must be invoked
+    verify(spyEntry.txnIndex).close()
+    verify(spyEntry.offsetIndex).close()
+    verify(spyEntry.timeIndex).close()
+
+    // index files must not be deleted
+    verify(spyEntry.txnIndex, times(0)).deleteIfExists()
+    verify(spyEntry.offsetIndex, times(0)).deleteIfExists()
+    verify(spyEntry.timeIndex, times(0)).deleteIfExists()
+
+    // verify cleaner thread is shutdown
+    assertTrue(cache.cleanerThread.isShutdownComplete)
+  }
+
+  @Test
+  def testConcurrentReadWriteAccessForCache(): Unit = {
+    val tpId = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0))
+    val metadataList = generateRemoteLogSegmentMetadata(size = 3, tpId)
+
+    assertCacheSize(0)
+    // getIndex for first time will call rsm#fetchIndex
+    cache.getIndexEntry(metadataList.head)
+    assertCacheSize(1)
+    verifyFetchIndexInvocation(count = 1, Seq(IndexType.OFFSET, IndexType.TIMESTAMP))
+    reset(rsm)
+
+    // Simulate a concurrency situation where one thread is reading the entry already present in the cache (cache hit)
+    // and the other thread is reading an entry which is not available in the cache (cache miss). The expected behaviour
+    // is for the former thread to succeed while the latter is fetching from rsm.
+    // In this test we simulate the situation using latches. We perform the following operations:
+    // 1. Start the CacheMiss thread and wait until it starts executing rsm.fetchIndex.
+    // 2. Block the CacheMiss thread inside the call to rsm.fetchIndex.
+    // 3. Start the CacheHit thread. Assert that it performs a successful read.
+    // 4. On completion of a successful read by the CacheHit thread, signal the CacheMiss thread to release its block.
+    // 5. Validate that the test passes. If the CacheMiss thread was blocking the CacheHit thread, the test will fail.
+    //
+    val latchForCacheHit = new CountDownLatch(1)
+    val latchForCacheMiss = new CountDownLatch(1)
+
+    val readerCacheHit = (() => {
+      // Wait for signal to start executing the read
+      logger.debug(s"Waiting for signal to begin read from 
${Thread.currentThread()}")
+      latchForCacheHit.await()
+      val entry = cache.getIndexEntry(metadataList.head)
+      assertNotNull(entry)
+      // Signal the CacheMiss to unblock itself
+      logger.debug(s"Signaling CacheMiss to unblock from 
${Thread.currentThread()}")
+      latchForCacheMiss.countDown()
+    }): Runnable
+
+    when(rsm.fetchIndex(any(classOf[RemoteLogSegmentMetadata]), any(classOf[IndexType])))
+      .thenAnswer(_ => {
+        logger.debug(s"Signaling CacheHit to begin read from ${Thread.currentThread()}")
+        latchForCacheHit.countDown()
+        logger.debug(s"Waiting for signal to complete rsm fetch from ${Thread.currentThread()}")
+        latchForCacheMiss.await()
+      })
+
+    val readerCacheMiss = (() => {
+      val entry = cache.getIndexEntry(metadataList.last)
+      assertNotNull(entry)
+    }): Runnable
+
+    val executor = Executors.newFixedThreadPool(2)
+    try {
+      executor.submit(readerCacheMiss: Runnable)
+      executor.submit(readerCacheHit: Runnable)
+      assertTrue(latchForCacheMiss.await(30, TimeUnit.SECONDS))
+    } finally {
+      executor.shutdownNow()
+    }
+  }
+
   @Test
   def testReloadCacheAfterClose(): Unit = {
-    val cache = new RemoteIndexCache(maxSize = 2, rsm, logDir = logDir.toString)
+    // close existing cache created in test setup before creating a new one
+    Utils.closeQuietly(cache, "RemoteIndexCache created for unit test")
+    cache = new RemoteIndexCache(maxSize = 2, rsm, logDir = logDir.toString)
     val tpId = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0))
     val metadataList = generateRemoteLogSegmentMetadata(size = 3, tpId)
 
-    assertEquals(0, cache.entries.size())
+    assertCacheSize(0)
     // getIndex for first time will call rsm#fetchIndex
     cache.getIndexEntry(metadataList.head)
-    assertEquals(1, cache.entries.size())
+    assertCacheSize(1)
     // Calling getIndex on the same entry should not call rsm#fetchIndex again, but it should retrieve from cache
     cache.getIndexEntry(metadataList.head)
-    assertEquals(1, cache.entries.size())
+    assertCacheSize(1)
     verifyFetchIndexInvocation(count = 1)
 
     // Here a new key metadataList(1) is invoked, that should call rsm#fetchIndex, making the count to 2
     cache.getIndexEntry(metadataList(1))
-    assertEquals(2, cache.entries.size())
+    assertCacheSize(2)
     // Calling getIndex on the same entry should not call rsm#fetchIndex again, but it should retrieve from cache
     cache.getIndexEntry(metadataList(1))
-    assertEquals(2, cache.entries.size())
+    assertCacheSize(2)
     verifyFetchIndexInvocation(count = 2)
 
-    // Here a new key metadataList(2) is invoked, that should call rsm#fetchIndex, making the count to 2
+    // Here a new key metadataList(2) is invoked, that should call rsm#fetchIndex
+    // The cache max size is 2, so it will remove one entry and keep the overall size at 2
     cache.getIndexEntry(metadataList(2))
-    assertEquals(2, cache.entries.size())
+    assertCacheSize(2)
     // Calling getIndex on the same entry should not call rsm#fetchIndex again, but it should retrieve from cache
     cache.getIndexEntry(metadataList(2))
-    assertEquals(2, cache.entries.size())
+    assertCacheSize(2)
     verifyFetchIndexInvocation(count = 3)
 
     // Close the cache
@@ -206,8 +378,33 @@ class RemoteIndexCacheTest {
 
     // Reload the cache from the disk and check the cache size is same as earlier
     val reloadedCache = new RemoteIndexCache(maxSize = 2, rsm, logDir = logDir.toString)
-    assertEquals(2, reloadedCache.entries.size())
+    assertEquals(2, reloadedCache.internalCache.asMap().size())
     reloadedCache.close()
+
+    verifyNoMoreInteractions(rsm)
+  }
+
+  private def generateSpyCacheEntry(): Entry = {
+    val timeIndex = spy(createTimeIndexForSegmentMetadata(rlsMetadata))
+    val txIndex = spy(createTxIndexForSegmentMetadata(rlsMetadata))
+    val offsetIndex = spy(createOffsetIndexForSegmentMetadata(rlsMetadata))
+    spy(new Entry(offsetIndex, timeIndex, txIndex))
+  }
+
+  private def assertAtLeastOnePresent(cache: RemoteIndexCache, uuids: Uuid*): Unit = {
+    uuids.foreach {
+      uuid => {
+        if (cache.internalCache.asMap().containsKey(uuid)) return
+      }
+    }
+    fail("all uuids are not present in cache")
+  }
+
+  private def assertCacheSize(expectedSize: Int): Unit = {
+    // Cache may grow beyond the size temporarily while evicting, hence, run in a loop to validate
+    // that cache reaches correct state eventually
+    TestUtils.waitUntilTrue(() => cache.internalCache.asMap().size() == expectedSize,
+      msg = s"cache did not adhere to expected size of $expectedSize")
   }
 
   private def verifyFetchIndexInvocation(count: Int,
@@ -218,6 +415,24 @@ class RemoteIndexCacheTest {
     }
   }
 
+  private def createTxIndexForSegmentMetadata(metadata: RemoteLogSegmentMetadata): TransactionIndex = {
+    val txnIdxFile = new File(tpDir, "txn-index" + UnifiedLog.TxnIndexFileSuffix)
+    txnIdxFile.createNewFile()
+    new TransactionIndex(metadata.startOffset(), txnIdxFile)
+  }
+
+  private def createTimeIndexForSegmentMetadata(metadata: RemoteLogSegmentMetadata): TimeIndex = {
+    val maxEntries = (metadata.endOffset() - metadata.startOffset()).asInstanceOf[Int]
+    new TimeIndex(new File(tpDir, String.valueOf(metadata.startOffset()) + UnifiedLog.TimeIndexFileSuffix),
+      metadata.startOffset(), maxEntries * 12)
+  }
+
+  private def createOffsetIndexForSegmentMetadata(metadata: RemoteLogSegmentMetadata) = {
+    val maxEntries = (metadata.endOffset() - metadata.startOffset()).asInstanceOf[Int]
+    new OffsetIndex(new File(tpDir, String.valueOf(metadata.startOffset()) + UnifiedLog.IndexFileSuffix),
+      metadata.startOffset(), maxEntries * 8)
+  }
+
   private def generateRemoteLogSegmentMetadata(size: Int,
                                                tpId: TopicIdPartition): List[RemoteLogSegmentMetadata] = {
     val metadataList = mutable.Buffer.empty[RemoteLogSegmentMetadata]
diff --git a/gradle/dependencies.gradle b/gradle/dependencies.gradle
index 56c3448b788..ca30e14e11e 100644
--- a/gradle/dependencies.gradle
+++ b/gradle/dependencies.gradle
@@ -61,6 +61,7 @@ versions += [
   apacheds: "2.0.0-M24",
   argparse4j: "0.7.0",
   bcpkix: "1.73",
+  caffeine: "2.9.3", // 3.x supports JDK 11 and above
   checkstyle: "8.36.2",
   commonsCli: "1.4",
   commonsValidator: "1.7",
@@ -145,6 +146,7 @@ libs += [
   apachedsJdbmPartition: "org.apache.directory.server:apacheds-jdbm-partition:$versions.apacheds",
   argparse4j: "net.sourceforge.argparse4j:argparse4j:$versions.argparse4j",
   bcpkix: "org.bouncycastle:bcpkix-jdk18on:$versions.bcpkix",
+  caffeine: "com.github.ben-manes.caffeine:caffeine:$versions.caffeine",
   commonsCli: "commons-cli:commons-cli:$versions.commonsCli",
   commonsValidator: "commons-validator:commons-validator:$versions.commonsValidator",
   easymock: "org.easymock:easymock:$versions.easymock",
diff --git a/server-common/src/main/java/org/apache/kafka/server/util/ShutdownableThread.java b/server-common/src/main/java/org/apache/kafka/server/util/ShutdownableThread.java
index 4ef727d3040..d4598cc3073 100644
--- a/server-common/src/main/java/org/apache/kafka/server/util/ShutdownableThread.java
+++ b/server-common/src/main/java/org/apache/kafka/server/util/ShutdownableThread.java
@@ -76,6 +76,9 @@ public abstract class ShutdownableThread extends Thread {
         return isShutdownComplete() && !isShutdownInitiated();
     }
 
+    /**
+     * @return true if the thread hasn't initiated shutdown already
+     */
     public boolean initiateShutdown() {
         synchronized (this) {
             if (isRunning()) {


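A note on the locking scheme used by the new RemoteIndexCache above: close() flips an AtomicBoolean (making close idempotent) and then takes the write lock, while reads hold the read lock, so in-flight reads drain before resources are released. A minimal Scala sketch of that pattern in isolation (illustrative names, not from this patch):

    import java.util.concurrent.atomic.AtomicBoolean
    import java.util.concurrent.locks.ReentrantReadWriteLock

    class ClosableResource extends AutoCloseable {
      private val closed = new AtomicBoolean(false)
      private val lock = new ReentrantReadWriteLock()

      def read(): String = {
        if (closed.get()) throw new IllegalStateException("already closed")
        val rl = lock.readLock()
        rl.lock() // many readers may hold the read lock concurrently
        try "value" finally rl.unlock()
      }

      override def close(): Unit = {
        if (!closed.getAndSet(true)) { // only the first close() proceeds
          val wl = lock.writeLock()
          wl.lock() // blocks until in-flight reads release the read lock
          try { /* release underlying resources here */ } finally wl.unlock()
        }
      }
    }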