This is an automated email from the ASF dual-hosted git repository.

showuon pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new e99984248da KAFKA-9550 Copying log segments to tiered storage in 
RemoteLogManager (#13487)
e99984248da is described below

commit e99984248da3042fd7fd6ed5f951f7222a4a3ccd
Author: Satish Duggana <[email protected]>
AuthorDate: Wed Apr 12 11:25:36 2023 +0530

    KAFKA-9550 Copying log segments to tiered storage in RemoteLogManager 
(#13487)
    
    Added functionality to copy log segments, indexes to the target remote 
storage for each topic partition enabled with tiered storage. This involves 
creating scheduled tasks for all leader partition replicas to copy their log 
segments in sequence to tiered storage.
    
    Reviewers: Jun Rao <[email protected]>, Luke Chen <[email protected]>
---
 checkstyle/import-control-core.xml                 |  11 +
 .../java/kafka/log/remote/RemoteLogManager.java    | 719 +++++++++++++++++++++
 core/src/main/scala/kafka/log/UnifiedLog.scala     |  13 +-
 .../scala/kafka/log/remote/RemoteLogManager.scala  | 289 ---------
 .../src/main/scala/kafka/server/BrokerServer.scala |   6 +-
 core/src/main/scala/kafka/server/KafkaServer.scala |  20 +-
 .../main/scala/kafka/server/ReplicaManager.scala   |   3 +
 .../kafka/log/remote/RemoteLogManagerTest.java     | 573 ++++++++++++++++
 .../test/scala/unit/kafka/log/UnifiedLogTest.scala |   6 +-
 .../kafka/log/remote/RemoteLogManagerTest.scala    | 277 --------
 .../apache/kafka/server/common/CheckpointFile.java |  33 +-
 .../checkpoint/InMemoryLeaderEpochCheckpoint.java  |  10 +-
 .../internals/epoch/LeaderEpochFileCache.java      |  12 +-
 .../internals/log/ProducerStateManager.java        |   4 +
 14 files changed, 1365 insertions(+), 611 deletions(-)

diff --git a/checkstyle/import-control-core.xml 
b/checkstyle/import-control-core.xml
index 0d5935f9f2b..a08563a3b5a 100644
--- a/checkstyle/import-control-core.xml
+++ b/checkstyle/import-control-core.xml
@@ -74,6 +74,17 @@
     <allow pkg="org.apache.kafka.server.util" />
   </subpackage>
 
+  <subpackage name="log.remote">
+    <allow pkg="org.apache.kafka.server.common" />
+    <allow pkg="org.apache.kafka.server.log.remote" />
+    <allow pkg="org.apache.kafka.storage.internals" />
+    <allow pkg="kafka.log" />
+    <allow pkg="kafka.cluster" />
+    <allow pkg="kafka.server" />
+    <allow pkg="org.mockito" />
+    <allow pkg="org.apache.kafka.test" />
+  </subpackage>
+
   <subpackage name="server">
     <allow pkg="kafka" />
     <allow pkg="org.apache.kafka" />
diff --git a/core/src/main/java/kafka/log/remote/RemoteLogManager.java 
b/core/src/main/java/kafka/log/remote/RemoteLogManager.java
new file mode 100644
index 00000000000..a2b7cd0b88d
--- /dev/null
+++ b/core/src/main/java/kafka/log/remote/RemoteLogManager.java
@@ -0,0 +1,719 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package kafka.log.remote;
+
+import kafka.cluster.Partition;
+import kafka.log.LogSegment;
+import kafka.log.UnifiedLog;
+import kafka.server.KafkaConfig;
+import org.apache.kafka.common.KafkaException;
+import org.apache.kafka.common.TopicIdPartition;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.Uuid;
+import org.apache.kafka.common.record.FileRecords;
+import org.apache.kafka.common.record.Record;
+import org.apache.kafka.common.record.RecordBatch;
+import org.apache.kafka.common.record.RemoteLogInputStream;
+import org.apache.kafka.common.utils.ChildFirstClassLoader;
+import org.apache.kafka.common.utils.KafkaThread;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.common.utils.Utils;
+import 
org.apache.kafka.server.log.remote.metadata.storage.ClassLoaderAwareRemoteLogMetadataManager;
+import 
org.apache.kafka.server.log.remote.storage.ClassLoaderAwareRemoteStorageManager;
+import org.apache.kafka.server.log.remote.storage.LogSegmentData;
+import org.apache.kafka.server.log.remote.storage.RemoteLogManagerConfig;
+import org.apache.kafka.server.log.remote.storage.RemoteLogMetadataManager;
+import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentId;
+import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadata;
+import 
org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadataUpdate;
+import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentState;
+import org.apache.kafka.server.log.remote.storage.RemoteStorageException;
+import org.apache.kafka.server.log.remote.storage.RemoteStorageManager;
+import 
org.apache.kafka.storage.internals.checkpoint.InMemoryLeaderEpochCheckpoint;
+import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache;
+import org.apache.kafka.storage.internals.log.EpochEntry;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import scala.Option;
+import scala.collection.JavaConverters;
+
+import java.io.Closeable;
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.lang.reflect.InvocationTargetException;
+import java.nio.ByteBuffer;
+import java.nio.file.Path;
+import java.security.AccessController;
+import java.security.PrivilegedAction;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.ListIterator;
+import java.util.Map;
+import java.util.Optional;
+import java.util.OptionalInt;
+import java.util.OptionalLong;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
+import java.util.concurrent.ScheduledFuture;
+import java.util.concurrent.ScheduledThreadPoolExecutor;
+import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Consumer;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+/**
+ * This class is responsible for
+ * - initializing `RemoteStorageManager` and `RemoteLogMetadataManager` 
instances
+ * - receives any leader and follower replica events and partition stop events 
and act on them
+ * - also provides APIs to fetch indexes, metadata about remote log segments
+ * - copying log segments to remote storage
+ */
+public class RemoteLogManager implements Closeable {
+
+    private static final Logger LOGGER = 
LoggerFactory.getLogger(RemoteLogManager.class);
+
+    private final RemoteLogManagerConfig rlmConfig;
+    private final int brokerId;
+    private final String logDir;
+    private final Time time;
+    private final Function<TopicPartition, Optional<UnifiedLog>> fetchLog;
+
+    private final RemoteStorageManager remoteLogStorageManager;
+
+    private final RemoteLogMetadataManager remoteLogMetadataManager;
+
+    private final RemoteIndexCache indexCache;
+
+    private final RLMScheduledThreadPool rlmScheduledThreadPool;
+
+    private final long delayInMs;
+
+    private final ConcurrentHashMap<TopicIdPartition, RLMTaskWithFuture> 
leaderOrFollowerTasks = new ConcurrentHashMap<>();
+
+    // topic ids that are received on leadership changes, this map is cleared 
on stop partitions
+    private final ConcurrentMap<TopicPartition, Uuid> topicPartitionIds = new 
ConcurrentHashMap<>();
+
+    private boolean closed = false;
+
+    /**
+     * Creates RemoteLogManager instance with the given arguments.
+     *
+     * @param rlmConfig Configuration required for remote logging 
subsystem(tiered storage) at the broker level.
+     * @param brokerId  id of the current broker.
+     * @param logDir    directory of Kafka log segments.
+     * @param time      Time instance.
+     * @param fetchLog  function to get UnifiedLog instance for a given topic.
+     */
+    public RemoteLogManager(RemoteLogManagerConfig rlmConfig,
+                            int brokerId,
+                            String logDir,
+                            Time time,
+                            Function<TopicPartition, Optional<UnifiedLog>> 
fetchLog) {
+
+        this.rlmConfig = rlmConfig;
+        this.brokerId = brokerId;
+        this.logDir = logDir;
+        this.time = time;
+        this.fetchLog = fetchLog;
+
+        remoteLogStorageManager = createRemoteStorageManager();
+        remoteLogMetadataManager = createRemoteLogMetadataManager();
+        indexCache = new RemoteIndexCache(1024, remoteLogStorageManager, 
logDir);
+        delayInMs = rlmConfig.remoteLogManagerTaskIntervalMs();
+        rlmScheduledThreadPool = new 
RLMScheduledThreadPool(rlmConfig.remoteLogManagerThreadPoolSize());
+    }
+
+    private <T> T createDelegate(ClassLoader classLoader, String className) {
+        try {
+            return (T) classLoader.loadClass(className)
+                    .getDeclaredConstructor().newInstance();
+        } catch (InstantiationException | IllegalAccessException | 
InvocationTargetException | NoSuchMethodException |
+                 ClassNotFoundException e) {
+            throw new KafkaException(e);
+        }
+    }
+
+    RemoteStorageManager createRemoteStorageManager() {
+        return AccessController.doPrivileged(new 
PrivilegedAction<RemoteStorageManager>() {
+            private final String classPath = 
rlmConfig.remoteStorageManagerClassPath();
+
+            public RemoteStorageManager run() {
+                if (classPath != null && !classPath.trim().isEmpty()) {
+                    ChildFirstClassLoader classLoader = new 
ChildFirstClassLoader(classPath, this.getClass().getClassLoader());
+                    RemoteStorageManager delegate = 
createDelegate(classLoader, rlmConfig.remoteStorageManagerClassName());
+                    return new ClassLoaderAwareRemoteStorageManager(delegate, 
classLoader);
+                } else {
+                    return createDelegate(this.getClass().getClassLoader(), 
rlmConfig.remoteStorageManagerClassName());
+                }
+            }
+        });
+    }
+
+    private void configureRSM() {
+        final Map<String, Object> rsmProps = new 
HashMap<>(rlmConfig.remoteStorageManagerProps());
+        rsmProps.put(KafkaConfig.BrokerIdProp(), brokerId);
+        remoteLogStorageManager.configure(rsmProps);
+    }
+
+    RemoteLogMetadataManager createRemoteLogMetadataManager() {
+        return AccessController.doPrivileged(new 
PrivilegedAction<RemoteLogMetadataManager>() {
+            private String classPath = 
rlmConfig.remoteLogMetadataManagerClassPath();
+
+            public RemoteLogMetadataManager run() {
+                if (classPath != null && !classPath.trim().isEmpty()) {
+                    ClassLoader classLoader = new 
ChildFirstClassLoader(classPath, this.getClass().getClassLoader());
+                    RemoteLogMetadataManager delegate = 
createDelegate(classLoader, rlmConfig.remoteLogMetadataManagerClassName());
+                    return new 
ClassLoaderAwareRemoteLogMetadataManager(delegate, classLoader);
+                } else {
+                    return createDelegate(this.getClass().getClassLoader(), 
rlmConfig.remoteLogMetadataManagerClassName());
+                }
+            }
+        });
+    }
+
    /**
     * Passes the user-supplied RLMM properties to the metadata manager. Unlike the RSM,
     * the RLMM additionally receives the log directory alongside the broker id.
     */
    private void configureRLMM() {
        final Map<String, Object> rlmmProps = new HashMap<>(rlmConfig.remoteLogMetadataManagerProps());

        rlmmProps.put(KafkaConfig.BrokerIdProp(), brokerId);
        rlmmProps.put(KafkaConfig.LogDirProp(), logDir);
        remoteLogMetadataManager.configure(rlmmProps);
    }
+
    /**
     * Configures (and thereby starts) the RSM and RLMM plugins. Called after construction,
     * once the broker is ready for the plugins to allocate their resources.
     */
    public void startup() {
        // Initialize and configure RSM and RLMM. This will start RSM, RLMM resources which may need to start resources
        // in connecting to the brokers or remote storages.
        configureRSM();
        configureRLMM();
    }
+
    /** @return the {@link RemoteStorageManager} instance created by this manager. */
    public RemoteStorageManager storageManager() {
        return remoteLogStorageManager;
    }
+
    // Keeps only the partitions whose log exists and has remote logging enabled.
    private Stream<Partition> filterPartitions(Set<Partition> partitions) {
        // We are not specifically checking for internal topics etc here as `log.remoteLogEnabled()` already handles that.
        return partitions.stream().filter(partition -> partition.log().exists(UnifiedLog::remoteLogEnabled));
    }
+
+    private void cacheTopicPartitionIds(TopicIdPartition topicIdPartition) {
+        Uuid previousTopicId = 
topicPartitionIds.put(topicIdPartition.topicPartition(), 
topicIdPartition.topicId());
+        if (previousTopicId != null && previousTopicId != 
topicIdPartition.topicId()) {
+            LOGGER.info("Previous cached topic id {} for {} does not match 
updated topic id {}",
+                    previousTopicId, topicIdPartition.topicPartition(), 
topicIdPartition.topicId());
+        }
+    }
+
    // Visible for testing: exposes the scheduled thread pool so tests can inspect it.
    public RLMScheduledThreadPool rlmScheduledThreadPool() {
        return rlmScheduledThreadPool;
    }
+
+    /**
+     * Callback to receive any leadership changes for the topic partitions 
assigned to this broker. If there are no
+     * existing tasks for a given topic partition then it will assign new 
leader or follower task else it will convert the
+     * task to respective target state(leader or follower).
+     *
+     * @param partitionsBecomeLeader   partitions that have become leaders on 
this broker.
+     * @param partitionsBecomeFollower partitions that have become followers 
on this broker.
+     * @param topicIds                 topic name to topic id mappings.
+     */
+    public void onLeadershipChange(Set<Partition> partitionsBecomeLeader,
+                                   Set<Partition> partitionsBecomeFollower,
+                                   Map<String, Uuid> topicIds) {
+        LOGGER.debug("Received leadership changes for leaders: {} and 
followers: {}", partitionsBecomeLeader, partitionsBecomeFollower);
+
+        Map<TopicIdPartition, Integer> leaderPartitionsWithLeaderEpoch = 
filterPartitions(partitionsBecomeLeader)
+                .collect(Collectors.toMap(
+                        partition -> new 
TopicIdPartition(topicIds.get(partition.topic()), partition.topicPartition()),
+                        partition -> partition.getLeaderEpoch()));
+        Set<TopicIdPartition> leaderPartitions = 
leaderPartitionsWithLeaderEpoch.keySet();
+
+        Set<TopicIdPartition> followerPartitions = 
filterPartitions(partitionsBecomeFollower)
+                .map(p -> new TopicIdPartition(topicIds.get(p.topic()), 
p.topicPartition())).collect(Collectors.toSet());
+
+        if (!leaderPartitions.isEmpty() || !followerPartitions.isEmpty()) {
+            LOGGER.debug("Effective topic partitions after filtering compact 
and internal topics, leaders: {} and followers: {}",
+                    leaderPartitions, followerPartitions);
+
+            leaderPartitions.forEach(this::cacheTopicPartitionIds);
+            followerPartitions.forEach(this::cacheTopicPartitionIds);
+
+            
remoteLogMetadataManager.onPartitionLeadershipChanges(leaderPartitions, 
followerPartitions);
+            followerPartitions.forEach(topicIdPartition ->
+                    doHandleLeaderOrFollowerPartitions(topicIdPartition, 
rlmTask -> rlmTask.convertToFollower()));
+
+            leaderPartitionsWithLeaderEpoch.forEach((topicIdPartition, 
leaderEpoch) ->
+                    doHandleLeaderOrFollowerPartitions(topicIdPartition,
+                            rlmTask -> rlmTask.convertToLeader(leaderEpoch)));
+        }
+    }
+
+    /**
+     * Deletes the internal topic partition info if delete flag is set as true.
+     *
+     * @param topicPartition topic partition to be stopped.
+     * @param delete         flag to indicate whether the given topic 
partitions to be deleted or not.
+     */
+    public void stopPartitions(TopicPartition topicPartition, boolean delete) {
+        if (delete) {
+            // Delete from internal datastructures only if it is to be deleted.
+            Uuid topicIdPartition = topicPartitionIds.remove(topicPartition);
+            LOGGER.debug("Removed partition: {} from topicPartitionIds", 
topicIdPartition);
+        }
+    }
+
    /**
     * Looks up remote log segment metadata for the given offset and leader epoch via the
     * {@link RemoteLogMetadataManager}.
     *
     * @param topicPartition topic partition to look up; its topic id must have been cached
     *                       previously (see cacheTopicPartitionIds).
     * @param epochForOffset leader epoch associated with the offset.
     * @param offset         offset to locate in remote storage.
     * @return the matching segment metadata, or empty if the metadata manager has none.
     * @throws KafkaException         if no topic id is registered for the partition.
     * @throws RemoteStorageException if the metadata lookup fails.
     */
    public Optional<RemoteLogSegmentMetadata> fetchRemoteLogSegmentMetadata(TopicPartition topicPartition,
                                                                            int epochForOffset,
                                                                            long offset) throws RemoteStorageException {
        Uuid topicId = topicPartitionIds.get(topicPartition);

        if (topicId == null) {
            throw new KafkaException("No topic id registered for topic partition: " + topicPartition);
        }
        return remoteLogMetadataManager.remoteLogSegmentMetadata(new TopicIdPartition(topicId, topicPartition), epochForOffset, offset);
    }
+
    // Scans the given remote segment for the first record with timestamp >= `timestamp`
    // and offset >= `startingOffset`, starting from the file position resolved through the
    // index cache. Returns empty when no such record exists in the segment.
    private Optional<FileRecords.TimestampAndOffset> lookupTimestamp(RemoteLogSegmentMetadata rlsMetadata, long timestamp, long startingOffset)
            throws RemoteStorageException, IOException {
        int startPos = indexCache.lookupTimestamp(rlsMetadata, timestamp, startingOffset);

        InputStream remoteSegInputStream = null;
        try {
            // Search forward for the position of the last offset that is greater than or equal to the startingOffset
            remoteSegInputStream = remoteLogStorageManager.fetchLogSegment(rlsMetadata, startPos);
            RemoteLogInputStream remoteLogInputStream = new RemoteLogInputStream(remoteSegInputStream);

            while (true) {
                RecordBatch batch = remoteLogInputStream.nextBatch();
                if (batch == null) break; // end of the segment stream
                // Only descend into batches that could contain a matching record.
                if (batch.maxTimestamp() >= timestamp && batch.lastOffset() >= startingOffset) {
                    for (Record record : batch) {
                        if (record.timestamp() >= timestamp && record.offset() >= startingOffset)
                            return Optional.of(new FileRecords.TimestampAndOffset(record.timestamp(), record.offset(), maybeLeaderEpoch(batch.partitionLeaderEpoch())));
                    }
                }
            }

            return Optional.empty();
        } finally {
            // Always release the remote stream, even on exception.
            Utils.closeQuietly(remoteSegInputStream, "RemoteLogSegmentInputStream");
        }
    }
+
+    private Optional<Integer> maybeLeaderEpoch(int leaderEpoch) {
+        return leaderEpoch == RecordBatch.NO_PARTITION_LEADER_EPOCH ? 
Optional.empty() : Optional.of(leaderEpoch);
+    }
+
+    /**
+     * Search the message offset in the remote storage based on timestamp and 
offset.
+     * <p>
+     * This method returns an option of TimestampOffset. The returned value is 
determined using the following ordered list of rules:
+     * <p>
+     * - If there are no messages in the remote storage, return None
+     * - If all the messages in the remote storage have smaller offsets, 
return None
+     * - If all the messages in the remote storage have smaller timestamps, 
return None
+     * - Otherwise, return an option of TimestampOffset. The offset is the 
offset of the first message whose timestamp
+     * is greater than or equals to the target timestamp and whose offset is 
greater than or equals to the startingOffset.
+     *
+     * @param tp               topic partition in which the offset to be found.
+     * @param timestamp        The timestamp to search for.
+     * @param startingOffset   The starting offset to search.
+     * @param leaderEpochCache LeaderEpochFileCache of the topic partition.
+     * @return the timestamp and offset of the first message that meets the 
requirements. None will be returned if there
+     * is no such message.
+     */
    public Optional<FileRecords.TimestampAndOffset> findOffsetByTimestamp(TopicPartition tp,
                                                                          long timestamp,
                                                                          long startingOffset,
                                                                          LeaderEpochFileCache leaderEpochCache) throws RemoteStorageException, IOException {
        Uuid topicId = topicPartitionIds.get(tp);
        if (topicId == null) {
            throw new KafkaException("Topic id does not exist for topic partition: " + tp);
        }

        // Get the respective epoch in which the starting-offset exists.
        OptionalInt maybeEpoch = leaderEpochCache.epochForOffset(startingOffset);
        // Walk epochs from the starting offset's epoch upward, scanning each epoch's
        // remote segments until one can contain a matching record.
        while (maybeEpoch.isPresent()) {
            int epoch = maybeEpoch.getAsInt();

            Iterator<RemoteLogSegmentMetadata> iterator = remoteLogMetadataManager.listRemoteLogSegments(new TopicIdPartition(topicId, tp), epoch);
            while (iterator.hasNext()) {
                RemoteLogSegmentMetadata rlsMetadata = iterator.next();
                // A segment can only match if its max timestamp and end offset both reach the targets.
                if (rlsMetadata.maxTimestampMs() >= timestamp && rlsMetadata.endOffset() >= startingOffset) {
                    // Delegates the record-level scan; empty here means no record in this
                    // segment satisfied both bounds.
                    return lookupTimestamp(rlsMetadata, timestamp, startingOffset);
                }
            }

            // Move to the next epoch if not found with the current epoch.
            maybeEpoch = leaderEpochCache.nextEpoch(epoch);
        }

        return Optional.empty();
    }
+
    /**
     * A {@link Runnable} with a cooperative cancellation flag. {@code cancel()} only sets
     * the flag; implementations are expected to poll {@link #isCancelled()} and stop on
     * their own. The flag is volatile, so a cancel from one thread is visible to the
     * thread running the task.
     */
    private static abstract class CancellableRunnable implements Runnable {
        private volatile boolean cancelled = false;

        public void cancel() {
            cancelled = true;
        }

        public boolean isCancelled() {
            return cancelled;
        }
    }
+
+    /**
+     * Returns the leader epoch checkpoint by truncating with the given 
start[exclusive] and end[inclusive] offset
+     *
+     * @param log         The actual log from where to take the leader-epoch 
checkpoint
+     * @param startOffset The start offset of the checkpoint file (exclusive 
in the truncation).
+     *                    If start offset is 6, then it will retain an entry 
at offset 6.
+     * @param endOffset   The end offset of the checkpoint file (inclusive in 
the truncation)
+     *                    If end offset is 100, then it will remove the 
entries greater than or equal to 100.
+     * @return the truncated leader epoch checkpoint
+     */
    InMemoryLeaderEpochCheckpoint getLeaderEpochCheckpoint(UnifiedLog log, long startOffset, long endOffset) {
        InMemoryLeaderEpochCheckpoint checkpoint = new InMemoryLeaderEpochCheckpoint();
        if (log.leaderEpochCache().isDefined()) {
            // writeTo copies the log's epoch entries into the checkpoint and returns a cache
            // over it; truncating that cache trims the checkpoint contents.
            // NOTE(review): assumes the returned cache is backed by `checkpoint` — confirm
            // against LeaderEpochFileCache.writeTo.
            LeaderEpochFileCache cache = log.leaderEpochCache().get().writeTo(checkpoint);
            if (startOffset >= 0) {
                // A negative startOffset means "do not truncate from the start".
                cache.truncateFromStart(startOffset);
            }
            cache.truncateFromEnd(endOffset);
        }

        return checkpoint;
    }
+
+    class RLMTask extends CancellableRunnable {
+
+        private final TopicIdPartition topicIdPartition;
+        private final Logger logger;
+
+        private volatile int leaderEpoch = -1;
+
        /**
         * Creates a task for the given partition. A new task starts as a follower
         * (leaderEpoch = -1) until convertToLeader is called.
         */
        public RLMTask(TopicIdPartition topicIdPartition) {
            this.topicIdPartition = topicIdPartition;
            // Per-partition logger so every line carries the broker id and partition.
            LogContext logContext = new LogContext("[RemoteLogManager=" + brokerId + " partition=" + topicIdPartition + "] ");
            logger = logContext.logger(RLMTask.class);
        }
+
        // The task acts as leader iff it has been assigned a non-negative leader epoch.
        boolean isLeader() {
            return leaderEpoch >= 0;
        }
+
        // The copiedOffsetOption is OptionalLong.empty() initially for a new leader RLMTask, and needs to be fetched inside the task's run() method.
        // Once resolved (see maybeUpdateReadOffset), it holds the highest offset already
        // copied to remote storage for this partition.
        private volatile OptionalLong copiedOffsetOption = OptionalLong.empty();
+
        /**
         * Promotes this task to leader with the given epoch and resets the copied-offset
         * state so it is re-resolved on the next run.
         *
         * @throws KafkaException if {@code leaderEpochVal} is negative.
         */
        public void convertToLeader(int leaderEpochVal) {
            if (leaderEpochVal < 0) {
                throw new KafkaException("leaderEpoch value for topic partition " + topicIdPartition + " can not be negative");
            }
            // Guard avoids a redundant volatile write when the epoch is unchanged.
            if (this.leaderEpoch != leaderEpochVal) {
                leaderEpoch = leaderEpochVal;
            }
            // Reset readOffset, so that it is set in next run of RLMTask
            copiedOffsetOption = OptionalLong.empty();
        }
+
        // Demotes this task to follower: a negative epoch marks it non-leader (see isLeader()).
        public void convertToFollower() {
            leaderEpoch = -1;
        }
+
        // Lazily resolves copiedOffsetOption on the first run after this task becomes leader.
        private void maybeUpdateReadOffset() throws RemoteStorageException {
            if (!copiedOffsetOption.isPresent()) {
                logger.info("Find the highest remote offset for partition: {} after becoming leader, leaderEpoch: {}", topicIdPartition, leaderEpoch);

                // This is found by traversing from the latest leader epoch from leader epoch history and find the highest offset
                // of a segment with that epoch copied into remote storage. If it can not find an entry then it checks for the
                // previous leader epoch till it finds an entry, If there are no entries till the earliest leader epoch in leader
                // epoch cache then it starts copying the segments from the earliest epoch entry’s offset.
                copiedOffsetOption = OptionalLong.of(findHighestRemoteOffset(topicIdPartition));
            }
        }
+
+        public void copyLogSegmentsToRemote() throws InterruptedException {
+            if (isCancelled())
+                return;
+
+            try {
+                maybeUpdateReadOffset();
+                long copiedOffset = copiedOffsetOption.getAsLong();
+                Optional<UnifiedLog> maybeLog = 
fetchLog.apply(topicIdPartition.topicPartition());
+                if (!maybeLog.isPresent()) {
+                    return;
+                }
+
+                UnifiedLog log = maybeLog.get();
+
+                // LSO indicates the offset below are ready to be consumed 
(high-watermark or committed)
+                long lso = log.lastStableOffset();
+                if (lso < 0) {
+                    logger.warn("lastStableOffset for partition {} is {}, 
which should not be negative.", topicIdPartition, lso);
+                } else if (lso > 0 && copiedOffset < lso) {
+                    // Copy segments only till the last-stable-offset as 
remote storage should contain only committed/acked
+                    // messages
+                    long toOffset = lso;
+                    logger.debug("Checking for segments to copy, copiedOffset: 
{} and toOffset: {}", copiedOffset, toOffset);
+                    long activeSegBaseOffset = 
log.activeSegment().baseOffset();
+                    // log-start-offset can be ahead of the read-offset, when:
+                    // 1) log-start-offset gets incremented via delete-records 
API (or)
+                    // 2) enabling the remote log for the first time
+                    long fromOffset = Math.max(copiedOffset + 1, 
log.logStartOffset());
+                    ArrayList<LogSegment> sortedSegments = new 
ArrayList<>(JavaConverters.asJavaCollection(log.logSegments(fromOffset, 
toOffset)));
+                    
sortedSegments.sort(Comparator.comparingLong(LogSegment::baseOffset));
+                    List<Long> sortedBaseOffsets = 
sortedSegments.stream().map(x -> x.baseOffset()).collect(Collectors.toList());
+                    int activeSegIndex = 
Collections.binarySearch(sortedBaseOffsets, activeSegBaseOffset);
+
+                    // sortedSegments becomes empty list when fromOffset and 
toOffset are same, and activeSegIndex becomes -1
+                    if (activeSegIndex < 0) {
+                        logger.debug("No segments found to be copied for 
partition {} with copiedOffset: {} and active segment's base-offset: {}",
+                                topicIdPartition, copiedOffset, 
activeSegBaseOffset);
+                    } else {
+                        ListIterator<LogSegment> logSegmentsIter = 
sortedSegments.subList(0, activeSegIndex).listIterator();
+                        while (logSegmentsIter.hasNext()) {
+                            LogSegment segment = logSegmentsIter.next();
+                            if (isCancelled() || !isLeader()) {
+                                logger.info("Skipping copying log segments as 
the current task state is changed, cancelled: {} leader:{}",
+                                        isCancelled(), isLeader());
+                                return;
+                            }
+
+                            copyLogSegment(log, segment, 
getNextSegmentBaseOffset(activeSegBaseOffset, logSegmentsIter));
+                        }
+                    }
+                } else {
+                    logger.debug("Skipping copying segments, current 
read-offset:{}, and LSO:{}", copiedOffset, lso);
+                }
+            } catch (InterruptedException ex) {
+                throw ex;
+            } catch (Exception ex) {
+                if (!isCancelled()) {
+                    logger.error("Error occurred while copying log segments of 
partition: {}", topicIdPartition, ex);
+                }
+            }
+        }
+
+        private long getNextSegmentBaseOffset(long activeSegBaseOffset, 
ListIterator<LogSegment> logSegmentsIter) {
+            long nextSegmentBaseOffset;
+            if (logSegmentsIter.hasNext()) {
+                nextSegmentBaseOffset = logSegmentsIter.next().baseOffset();
+                logSegmentsIter.previous();
+            } else {
+                nextSegmentBaseOffset = activeSegBaseOffset;
+            }
+
+            return nextSegmentBaseOffset;
+        }
+
+        private void copyLogSegment(UnifiedLog log, LogSegment segment, long 
nextSegmentBaseOffset) throws InterruptedException, ExecutionException, 
RemoteStorageException, IOException {
+            File logFile = segment.log().file();
+            String logFileName = logFile.getName();
+
+            logger.info("Copying {} to remote storage.", logFileName);
+            RemoteLogSegmentId id = new RemoteLogSegmentId(topicIdPartition, 
Uuid.randomUuid());
+
+            long endOffset = nextSegmentBaseOffset - 1;
+            File producerStateSnapshotFile = 
log.producerStateManager().fetchSnapshot(nextSegmentBaseOffset).orElse(null);
+
+            List<EpochEntry> epochEntries = getLeaderEpochCheckpoint(log, 
segment.baseOffset(), nextSegmentBaseOffset).read();
+            Map<Integer, Long> segmentLeaderEpochs = new 
HashMap<>(epochEntries.size());
+            epochEntries.forEach(entry -> segmentLeaderEpochs.put(entry.epoch, 
entry.startOffset));
+
+            RemoteLogSegmentMetadata copySegmentStartedRlsm = new 
RemoteLogSegmentMetadata(id, segment.baseOffset(), endOffset,
+                    segment.largestTimestamp(), brokerId, time.milliseconds(), 
segment.log().sizeInBytes(),
+                    segmentLeaderEpochs);
+
+            
remoteLogMetadataManager.addRemoteLogSegmentMetadata(copySegmentStartedRlsm).get();
+
+            ByteBuffer leaderEpochsIndex = getLeaderEpochCheckpoint(log, -1, 
nextSegmentBaseOffset).readAsByteBuffer();
+            LogSegmentData segmentData = new LogSegmentData(logFile.toPath(), 
toPathIfExists(segment.lazyOffsetIndex().get().file()),
+                    toPathIfExists(segment.lazyTimeIndex().get().file()), 
Optional.ofNullable(toPathIfExists(segment.txnIndex().file())),
+                    producerStateSnapshotFile.toPath(), leaderEpochsIndex);
+            remoteLogStorageManager.copyLogSegmentData(copySegmentStartedRlsm, 
segmentData);
+
+            RemoteLogSegmentMetadataUpdate copySegmentFinishedRlsm = new 
RemoteLogSegmentMetadataUpdate(id, time.milliseconds(),
+                    RemoteLogSegmentState.COPY_SEGMENT_FINISHED, brokerId);
+
+            
remoteLogMetadataManager.updateRemoteLogSegmentMetadata(copySegmentFinishedRlsm).get();
+
+            copiedOffsetOption = OptionalLong.of(endOffset);
+            log.updateHighestOffsetInRemoteStorage(endOffset);
+            logger.info("Copied {} to remote storage with segment-id: {}", 
logFileName, copySegmentFinishedRlsm.remoteLogSegmentId());
+        }
+
+        private Path toPathIfExists(File file) {
+            return file.exists() ? file.toPath() : null;
+        }
+
+        public void run() {
+            if (isCancelled())
+                return;
+
+            try {
+                if (isLeader()) {
+                    // Copy log segments to remote storage
+                    copyLogSegmentsToRemote();
+                }
+            } catch (InterruptedException ex) {
+                if (!isCancelled()) {
+                    logger.warn("Current thread for topic-partition-id {} is 
interrupted, this task won't be rescheduled. " +
+                            "Reason: {}", topicIdPartition, ex.getMessage());
+                }
+            } catch (Exception ex) {
+                if (!isCancelled()) {
+                    logger.warn("Current task for topic-partition {} received 
error but it will be scheduled. " +
+                            "Reason: {}", topicIdPartition, ex.getMessage());
+                }
+            }
+        }
+
+        public String toString() {
+            return this.getClass().toString() + "[" + topicIdPartition + "]";
+        }
+    }
+
+    long findHighestRemoteOffset(TopicIdPartition topicIdPartition) throws 
RemoteStorageException {
+        Optional<Long> offset = Optional.empty();
+        Optional<UnifiedLog> maybeLog = 
fetchLog.apply(topicIdPartition.topicPartition());
+        if (maybeLog.isPresent()) {
+            UnifiedLog log = maybeLog.get();
+            Option<LeaderEpochFileCache> maybeLeaderEpochFileCache = 
log.leaderEpochCache();
+            if (maybeLeaderEpochFileCache.isDefined()) {
+                LeaderEpochFileCache cache = maybeLeaderEpochFileCache.get();
+                OptionalInt epoch = cache.latestEpoch();
+                while (!offset.isPresent() && epoch.isPresent()) {
+                    offset = 
remoteLogMetadataManager.highestOffsetForEpoch(topicIdPartition, 
epoch.getAsInt());
+                    epoch = cache.previousEpoch(epoch.getAsInt());
+                }
+            }
+        }
+
+        return offset.orElse(-1L);
+    }
+
+    void doHandleLeaderOrFollowerPartitions(TopicIdPartition topicPartition,
+                                            Consumer<RLMTask> 
convertToLeaderOrFollower) {
+        RLMTaskWithFuture rlmTaskWithFuture = 
leaderOrFollowerTasks.computeIfAbsent(topicPartition,
+                topicIdPartition -> {
+                    RLMTask task = new RLMTask(topicIdPartition);
+                    // set this upfront when it is getting initialized instead 
of doing it after scheduling.
+                    convertToLeaderOrFollower.accept(task);
+                    LOGGER.info("Created a new task: {} and getting 
scheduled", task);
+                    ScheduledFuture future = 
rlmScheduledThreadPool.scheduleWithFixedDelay(task, 0, delayInMs, 
TimeUnit.MILLISECONDS);
+                    return new RLMTaskWithFuture(task, future);
+                }
+        );
+        convertToLeaderOrFollower.accept(rlmTaskWithFuture.rlmTask);
+    }
+
+    static class RLMTaskWithFuture {
+
+        private final RLMTask rlmTask;
+        private final Future<?> future;
+
+        RLMTaskWithFuture(RLMTask rlmTask, Future<?> future) {
+            this.rlmTask = rlmTask;
+            this.future = future;
+        }
+
+        public void cancel() {
+            rlmTask.cancel();
+            try {
+                future.cancel(true);
+            } catch (Exception ex) {
+                LOGGER.error("Error occurred while canceling the task: {}", 
rlmTask, ex);
+            }
+        }
+
+    }
+
+    /**
+     * Closes and releases all the resources like RemoterStorageManager and 
RemoteLogMetadataManager.
+     */
+    public void close() {
+        synchronized (this) {
+            if (!closed) {
+                
leaderOrFollowerTasks.values().forEach(RLMTaskWithFuture::cancel);
+                Utils.closeQuietly(remoteLogStorageManager, 
"RemoteLogStorageManager");
+                Utils.closeQuietly(remoteLogMetadataManager, 
"RemoteLogMetadataManager");
+                Utils.closeQuietly(indexCache, "RemoteIndexCache");
+                try {
+                    rlmScheduledThreadPool.shutdown();
+                } catch (InterruptedException e) {
+                    // ignore
+                }
+                leaderOrFollowerTasks.clear();
+                closed = true;
+            }
+        }
+    }
+
+    static class RLMScheduledThreadPool {
+
+        private static final Logger LOGGER = 
LoggerFactory.getLogger(RLMScheduledThreadPool.class);
+        private final int poolSize;
+        private final ScheduledThreadPoolExecutor scheduledThreadPool;
+
+        public RLMScheduledThreadPool(int poolSize) {
+            this.poolSize = poolSize;
+            scheduledThreadPool = createPool();
+        }
+
+        private ScheduledThreadPoolExecutor createPool() {
+            ScheduledThreadPoolExecutor threadPool = new 
ScheduledThreadPoolExecutor(poolSize);
+            threadPool.setRemoveOnCancelPolicy(true);
+            
threadPool.setExecuteExistingDelayedTasksAfterShutdownPolicy(false);
+            
threadPool.setContinueExistingPeriodicTasksAfterShutdownPolicy(false);
+            threadPool.setThreadFactory(new ThreadFactory() {
+                private final AtomicInteger sequence = new AtomicInteger();
+
+                public Thread newThread(Runnable r) {
+                    return KafkaThread.daemon("kafka-rlm-thread-pool-" + 
sequence.incrementAndGet(), r);
+                }
+            });
+
+            return threadPool;
+        }
+
+        public ScheduledFuture scheduleWithFixedDelay(Runnable runnable, long 
initialDelay, long delay, TimeUnit timeUnit) {
+            LOGGER.info("Scheduling runnable {} with initial delay: {}, fixed 
delay: {}", runnable, initialDelay, delay);
+            return scheduledThreadPool.scheduleWithFixedDelay(runnable, 
initialDelay, delay, timeUnit);
+        }
+
+        public boolean shutdown() throws InterruptedException {
+            LOGGER.info("Shutting down scheduled thread pool");
+            scheduledThreadPool.shutdownNow();
+            //waits for 2 mins to terminate the current tasks
+            return scheduledThreadPool.awaitTermination(2, TimeUnit.MINUTES);
+        }
+    }
+
+}
\ No newline at end of file
diff --git a/core/src/main/scala/kafka/log/UnifiedLog.scala 
b/core/src/main/scala/kafka/log/UnifiedLog.scala
index dbfab74d3d0..5db020d0adf 100644
--- a/core/src/main/scala/kafka/log/UnifiedLog.scala
+++ b/core/src/main/scala/kafka/log/UnifiedLog.scala
@@ -149,6 +149,8 @@ class UnifiedLog(@volatile var logStartOffset: Long,
 
   def localLogStartOffset(): Long = _localLogStartOffset
 
+  @volatile private var highestOffsetInRemoteStorage: Long = -1L
+
   locally {
     initializePartitionMetadata()
     updateLogStartOffset(logStartOffset)
@@ -520,6 +522,11 @@ class UnifiedLog(@volatile var logStartOffset: Long,
       localLog.updateRecoveryPoint(offset)
     }
   }
  /**
   * Records the highest offset known to have been copied to remote storage for this log.
   * When remote storage is not enabled this is a no-op apart from a warning; otherwise the
   * tracked value only ever moves forward (offsets at or below the current value are ignored).
   */
  def updateHighestOffsetInRemoteStorage(offset: Long): Unit = {
    if (!remoteLogEnabled())
      warn(s"Unable to update the highest offset in remote storage with offset $offset since remote storage is not enabled. The existing highest offset is $highestOffsetInRemoteStorage.")
    else if (offset > highestOffsetInRemoteStorage) highestOffsetInRemoteStorage = offset
  }
 
   // Rebuild producer state until lastOffset. This method may be called from 
the recovery code path, and thus must be
   // free of all side-effects, i.e. it must not update any log-specific state.
@@ -1231,10 +1238,10 @@ class UnifiedLog(@volatile var logStartOffset: Long,
           }
 
           remoteLogManager.get.findOffsetByTimestamp(topicPartition, 
targetTimestamp, logStartOffset, leaderEpochCache.get)
-        } else None
+        } else Optional.empty()
 
-        if (remoteOffset.nonEmpty) {
-          remoteOffset
+        if (remoteOffset.isPresent) {
+          remoteOffset.asScala
         } else {
           // If it is not found in remote storage, search in the local storage 
starting with local log start offset.
 
diff --git a/core/src/main/scala/kafka/log/remote/RemoteLogManager.scala 
b/core/src/main/scala/kafka/log/remote/RemoteLogManager.scala
deleted file mode 100644
index a0fa0058000..00000000000
--- a/core/src/main/scala/kafka/log/remote/RemoteLogManager.scala
+++ /dev/null
@@ -1,289 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package kafka.log.remote
-
-import kafka.cluster.Partition
-import kafka.server.KafkaConfig
-import kafka.utils.Logging
-import org.apache.kafka.common._
-import org.apache.kafka.common.record.FileRecords.TimestampAndOffset
-import org.apache.kafka.common.record.{RecordBatch, RemoteLogInputStream}
-import org.apache.kafka.common.utils.{ChildFirstClassLoader, Utils}
-import 
org.apache.kafka.server.log.remote.metadata.storage.ClassLoaderAwareRemoteLogMetadataManager
-import org.apache.kafka.server.log.remote.storage._
-import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache
-
-import java.io.{Closeable, InputStream}
-import java.security.{AccessController, PrivilegedAction}
-import java.util
-import java.util.Optional
-import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
-import scala.collection.Set
-import scala.jdk.CollectionConverters._
-
-/**
- * This class is responsible for
- *  - initializing `RemoteStorageManager` and `RemoteLogMetadataManager` 
instances.
- *  - receives any leader and follower replica events and partition stop 
events and act on them
- *  - also provides APIs to fetch indexes, metadata about remote log segments.
- *
- * @param rlmConfig Configuration required for remote logging subsystem(tiered 
storage) at the broker level.
- * @param brokerId  id of the current broker.
- * @param logDir    directory of Kafka log segments.
- */
-class RemoteLogManager(rlmConfig: RemoteLogManagerConfig,
-                       brokerId: Int,
-                       logDir: String) extends Logging with Closeable {
-
-  // topic ids received on leadership changes
-  private val topicPartitionIds: ConcurrentMap[TopicPartition, Uuid] = new 
ConcurrentHashMap[TopicPartition, Uuid]()
-
-  private val remoteLogStorageManager: RemoteStorageManager = 
createRemoteStorageManager()
-  private val remoteLogMetadataManager: RemoteLogMetadataManager = 
createRemoteLogMetadataManager()
-
-  private val indexCache = new RemoteIndexCache(remoteStorageManager = 
remoteLogStorageManager, logDir = logDir)
-
-  private var closed = false
-
-  private[remote] def createRemoteStorageManager(): RemoteStorageManager = {
-    def createDelegate(classLoader: ClassLoader): RemoteStorageManager = {
-      classLoader.loadClass(rlmConfig.remoteStorageManagerClassName())
-        
.getDeclaredConstructor().newInstance().asInstanceOf[RemoteStorageManager]
-    }
-
-    AccessController.doPrivileged(new PrivilegedAction[RemoteStorageManager] {
-      private val classPath = rlmConfig.remoteStorageManagerClassPath()
-
-      override def run(): RemoteStorageManager = {
-          if (classPath != null && classPath.trim.nonEmpty) {
-            val classLoader = new ChildFirstClassLoader(classPath, 
this.getClass.getClassLoader)
-            val delegate = createDelegate(classLoader)
-            new ClassLoaderAwareRemoteStorageManager(delegate, classLoader)
-          } else {
-            createDelegate(this.getClass.getClassLoader)
-          }
-      }
-    })
-  }
-
-  private def configureRSM(): Unit = {
-    val rsmProps = new util.HashMap[String, Any]()
-    rlmConfig.remoteStorageManagerProps().asScala.foreach { case (k, v) => 
rsmProps.put(k, v) }
-    rsmProps.put(KafkaConfig.BrokerIdProp, brokerId)
-    remoteLogStorageManager.configure(rsmProps)
-  }
-
-  private[remote] def createRemoteLogMetadataManager(): 
RemoteLogMetadataManager = {
-    def createDelegate(classLoader: ClassLoader) = {
-      classLoader.loadClass(rlmConfig.remoteLogMetadataManagerClassName())
-        .getDeclaredConstructor()
-        .newInstance()
-        .asInstanceOf[RemoteLogMetadataManager]
-    }
-
-    AccessController.doPrivileged(new 
PrivilegedAction[RemoteLogMetadataManager] {
-      private val classPath = rlmConfig.remoteLogMetadataManagerClassPath
-
-      override def run(): RemoteLogMetadataManager = {
-        if (classPath != null && classPath.trim.nonEmpty) {
-          val classLoader = new ChildFirstClassLoader(classPath, 
this.getClass.getClassLoader)
-          val delegate = createDelegate(classLoader)
-          new ClassLoaderAwareRemoteLogMetadataManager(delegate, classLoader)
-        } else {
-          createDelegate(this.getClass.getClassLoader)
-        }
-      }
-    })
-  }
-
-  private def configureRLMM(): Unit = {
-    val rlmmProps = new util.HashMap[String, Any]()
-    rlmConfig.remoteLogMetadataManagerProps().asScala.foreach { case (k, v) => 
rlmmProps.put(k, v) }
-    rlmmProps.put(KafkaConfig.BrokerIdProp, brokerId)
-    rlmmProps.put(KafkaConfig.LogDirProp, logDir)
-    remoteLogMetadataManager.configure(rlmmProps)
-  }
-
-  def startup(): Unit = {
-    // Initialize and configure RSM and RLMM. This will start RSM, RLMM 
resources which may need to start resources
-    // in connecting to the brokers or remote storages.
-    configureRSM()
-    configureRLMM()
-  }
-
-  def storageManager(): RemoteStorageManager = {
-    remoteLogStorageManager
-  }
-
-  /**
-   * Callback to receive any leadership changes for the topic partitions 
assigned to this broker. If there are no
-   * existing tasks for a given topic partition then it will assign new leader 
or follower task else it will convert the
-   * task to respective target state(leader or follower).
-   *
-   * @param partitionsBecomeLeader   partitions that have become leaders on 
this broker.
-   * @param partitionsBecomeFollower partitions that have become followers on 
this broker.
-   * @param topicIds                 topic name to topic id mappings.
-   */
-  def onLeadershipChange(partitionsBecomeLeader: Set[Partition],
-                         partitionsBecomeFollower: Set[Partition],
-                         topicIds: util.Map[String, Uuid]): Unit = {
-    debug(s"Received leadership changes for leaders: $partitionsBecomeLeader 
and followers: $partitionsBecomeFollower")
-
-    // Partitions logs are available when this callback is invoked.
-    // Compact topics and internal topics are filtered here as they are not 
supported with tiered storage.
-    def filterPartitions(partitions: Set[Partition]): Set[TopicIdPartition] = {
-      // We are not specifically checking for internal topics etc here as 
`log.remoteLogEnabled()` already handles that.
-      partitions.filter(partition => partition.log.exists(log => 
log.remoteLogEnabled()))
-        .map(partition => new TopicIdPartition(topicIds.get(partition.topic), 
partition.topicPartition))
-    }
-
-    val followerTopicPartitions = filterPartitions(partitionsBecomeFollower)
-    val leaderTopicPartitions = filterPartitions(partitionsBecomeLeader)
-    debug(s"Effective topic partitions after filtering compact and internal 
topics, leaders: $leaderTopicPartitions " +
-      s"and followers: $followerTopicPartitions")
-
-    if (leaderTopicPartitions.nonEmpty || followerTopicPartitions.nonEmpty) {
-      leaderTopicPartitions.foreach(x => 
topicPartitionIds.put(x.topicPartition(), x.topicId()))
-      followerTopicPartitions.foreach(x => 
topicPartitionIds.put(x.topicPartition(), x.topicId()))
-
-      
remoteLogMetadataManager.onPartitionLeadershipChanges(leaderTopicPartitions.asJava,
 followerTopicPartitions.asJava)
-    }
-  }
-
-  /**
-   * Deletes the internal topic partition info if delete flag is set as true.
-   *
-   * @param topicPartition topic partition to be stopped.
-   * @param delete         flag to indicate whether the given topic partitions 
to be deleted or not.
-   */
-  def stopPartitions(topicPartition: TopicPartition, delete: Boolean): Unit = {
-    if (delete) {
-      // Delete from internal datastructures only if it is to be deleted.
-      val topicIdPartition = topicPartitionIds.remove(topicPartition)
-      debug(s"Removed partition: $topicIdPartition from topicPartitionIds")
-    }
-  }
-
-  def fetchRemoteLogSegmentMetadata(topicPartition: TopicPartition,
-                                    epochForOffset: Int,
-                                    offset: Long): 
Optional[RemoteLogSegmentMetadata] = {
-    val topicId = topicPartitionIds.get(topicPartition)
-
-    if (topicId == null) {
-      throw new KafkaException("No topic id registered for topic partition: " 
+ topicPartition)
-    }
-
-    remoteLogMetadataManager.remoteLogSegmentMetadata(new 
TopicIdPartition(topicId, topicPartition), epochForOffset, offset)
-  }
-
-  private def lookupTimestamp(rlsMetadata: RemoteLogSegmentMetadata, 
timestamp: Long, startingOffset: Long): Option[TimestampAndOffset] = {
-    val startPos = indexCache.lookupTimestamp(rlsMetadata, timestamp, 
startingOffset)
-
-    var remoteSegInputStream: InputStream = null
-    try {
-      // Search forward for the position of the last offset that is greater 
than or equal to the startingOffset
-      remoteSegInputStream = 
remoteLogStorageManager.fetchLogSegment(rlsMetadata, startPos)
-      val remoteLogInputStream = new RemoteLogInputStream(remoteSegInputStream)
-      var batch: RecordBatch = null
-
-      def nextBatch(): RecordBatch = {
-        batch = remoteLogInputStream.nextBatch()
-        batch
-      }
-
-      while (nextBatch() != null) {
-        if (batch.maxTimestamp >= timestamp && batch.lastOffset >= 
startingOffset) {
-          batch.iterator.asScala.foreach(record => {
-            if (record.timestamp >= timestamp && record.offset >= 
startingOffset)
-              return Some(new TimestampAndOffset(record.timestamp, 
record.offset, maybeLeaderEpoch(batch.partitionLeaderEpoch)))
-          })
-        }
-      }
-      None
-    } finally {
-      Utils.closeQuietly(remoteSegInputStream, "RemoteLogSegmentInputStream")
-    }
-  }
-
-  private def maybeLeaderEpoch(leaderEpoch: Int): Optional[Integer] = {
-    if (leaderEpoch == RecordBatch.NO_PARTITION_LEADER_EPOCH)
-      Optional.empty()
-    else
-      Optional.of(leaderEpoch)
-  }
-
-  /**
-   * Search the message offset in the remote storage based on timestamp and 
offset.
-   *
-   * This method returns an option of TimestampOffset. The returned value is 
determined using the following ordered list of rules:
-   *
-   * - If there are no messages in the remote storage, return None
-   * - If all the messages in the remote storage have smaller offsets, return 
None
-   * - If all the messages in the remote storage have smaller timestamps, 
return None
-   * - Otherwise, return an option of TimestampOffset. The offset is the 
offset of the first message whose timestamp
-   * is greater than or equals to the target timestamp and whose offset is 
greater than or equals to the startingOffset.
-   *
-   * @param tp               topic partition in which the offset to be found.
-   * @param timestamp        The timestamp to search for.
-   * @param startingOffset   The starting offset to search.
-   * @param leaderEpochCache LeaderEpochFileCache of the topic partition.
-   * @return the timestamp and offset of the first message that meets the 
requirements. None will be returned if there
-   *         is no such message.
-   */
-  def findOffsetByTimestamp(tp: TopicPartition,
-                            timestamp: Long,
-                            startingOffset: Long,
-                            leaderEpochCache: LeaderEpochFileCache): 
Option[TimestampAndOffset] = {
-    val topicId = topicPartitionIds.get(tp)
-    if (topicId == null) {
-      throw new KafkaException("Topic id does not exist for topic partition: " 
+ tp)
-    }
-
-    // Get the respective epoch in which the starting-offset exists.
-    var maybeEpoch = leaderEpochCache.epochForOffset(startingOffset)
-    while (maybeEpoch.isPresent) {
-      val epoch = maybeEpoch.getAsInt
-      remoteLogMetadataManager.listRemoteLogSegments(new 
TopicIdPartition(topicId, tp), epoch).asScala
-        .foreach(rlsMetadata =>
-          if (rlsMetadata.maxTimestampMs() >= timestamp && 
rlsMetadata.endOffset() >= startingOffset) {
-            val timestampOffset = lookupTimestamp(rlsMetadata, timestamp, 
startingOffset)
-            if (timestampOffset.isDefined)
-              return timestampOffset
-          }
-        )
-
-      // Move to the next epoch if not found with the current epoch.
-      maybeEpoch = leaderEpochCache.nextEpoch(epoch)
-    }
-    None
-  }
-
-  /**
-   * Closes and releases all the resources like RemoterStorageManager and 
RemoteLogMetadataManager.
-   */
-  def close(): Unit = {
-    this synchronized {
-      if (!closed) {
-        Utils.closeQuietly(remoteLogStorageManager, "RemoteLogStorageManager")
-        Utils.closeQuietly(remoteLogMetadataManager, 
"RemoteLogMetadataManager")
-        Utils.closeQuietly(indexCache, "RemoteIndexCache")
-        closed = true
-      }
-    }
-  }
-
-}
\ No newline at end of file
diff --git a/core/src/main/scala/kafka/server/BrokerServer.scala 
b/core/src/main/scala/kafka/server/BrokerServer.scala
index e3a657cd982..605455b8273 100644
--- a/core/src/main/scala/kafka/server/BrokerServer.scala
+++ b/core/src/main/scala/kafka/server/BrokerServer.scala
@@ -35,7 +35,7 @@ import org.apache.kafka.common.security.auth.SecurityProtocol
 import org.apache.kafka.common.security.scram.internals.ScramMechanism
 import 
org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache
 import org.apache.kafka.common.utils.{LogContext, Time, Utils}
-import org.apache.kafka.common.{ClusterResource, Endpoint, KafkaException}
+import org.apache.kafka.common.{ClusterResource, Endpoint, KafkaException, 
TopicPartition}
 import org.apache.kafka.coordinator.group.GroupCoordinator
 import org.apache.kafka.image.publisher.MetadataPublisher
 import org.apache.kafka.metadata.authorizer.ClusterMetadataAuthorizer
@@ -54,6 +54,7 @@ import java.util.concurrent.atomic.AtomicBoolean
 import java.util.concurrent.locks.ReentrantLock
 import java.util.concurrent.{CompletableFuture, ExecutionException, TimeUnit, 
TimeoutException}
 import scala.collection.{Map, Seq}
+import scala.compat.java8.OptionConverters.RichOptionForJava8
 import scala.jdk.CollectionConverters._
 
 
@@ -513,7 +514,8 @@ class BrokerServer(
         throw new KafkaException("Tiered storage is not supported with 
multiple log dirs.");
       }
 
-      Some(new RemoteLogManager(remoteLogManagerConfig, config.brokerId, 
config.logDirs.head))
+      Some(new RemoteLogManager(remoteLogManagerConfig, config.brokerId, 
config.logDirs.head, time,
+        (tp: TopicPartition) => logManager.getLog(tp).asJava));
     } else {
       None
     }
diff --git a/core/src/main/scala/kafka/server/KafkaServer.scala 
b/core/src/main/scala/kafka/server/KafkaServer.scala
index 5e5995e5a66..d2fbb607324 100755
--- a/core/src/main/scala/kafka/server/KafkaServer.scala
+++ b/core/src/main/scala/kafka/server/KafkaServer.scala
@@ -17,18 +17,14 @@
 
 package kafka.server
 
-import java.io.{File, IOException}
-import java.net.{InetAddress, SocketTimeoutException}
-import java.util.concurrent._
-import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}
 import kafka.cluster.{Broker, EndPoint}
 import kafka.common.{GenerateBrokerIdException, InconsistentBrokerIdException, 
InconsistentClusterIdException}
 import kafka.controller.KafkaController
 import kafka.coordinator.group.GroupCoordinatorAdapter
 import kafka.coordinator.transaction.{ProducerIdManager, 
TransactionCoordinator}
 import kafka.log.LogManager
-import kafka.metrics.KafkaMetricsReporter
 import kafka.log.remote.RemoteLogManager
+import kafka.metrics.KafkaMetricsReporter
 import kafka.network.{ControlPlaneAcceptor, DataPlaneAcceptor, RequestChannel, 
SocketServer}
 import kafka.raft.KafkaRaftManager
 import kafka.security.CredentialProvider
@@ -48,21 +44,26 @@ import 
org.apache.kafka.common.security.scram.internals.ScramMechanism
 import 
org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache
 import org.apache.kafka.common.security.{JaasContext, JaasUtils}
 import org.apache.kafka.common.utils.{AppInfoParser, LogContext, Time, Utils}
-import org.apache.kafka.common.{Endpoint, KafkaException, Node}
+import org.apache.kafka.common.{Endpoint, KafkaException, Node, TopicPartition}
 import org.apache.kafka.coordinator.group.GroupCoordinator
 import org.apache.kafka.metadata.{BrokerState, MetadataRecordSerde, 
VersionRange}
 import org.apache.kafka.raft.RaftConfig
 import org.apache.kafka.server.authorizer.Authorizer
-import org.apache.kafka.server.common.{ApiMessageAndVersion, MetadataVersion}
 import org.apache.kafka.server.common.MetadataVersion._
+import org.apache.kafka.server.common.{ApiMessageAndVersion, MetadataVersion}
 import org.apache.kafka.server.fault.LoggingFaultHandler
-import org.apache.kafka.server.metrics.KafkaYammerMetrics
 import org.apache.kafka.server.log.remote.storage.RemoteLogManagerConfig
+import org.apache.kafka.server.metrics.KafkaYammerMetrics
 import org.apache.kafka.server.util.KafkaScheduler
 import org.apache.kafka.storage.internals.log.LogDirFailureChannel
 import org.apache.zookeeper.client.ZKClientConfig
 
+import java.io.{File, IOException}
+import java.net.{InetAddress, SocketTimeoutException}
+import java.util.concurrent._
+import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}
 import scala.collection.{Map, Seq}
+import scala.compat.java8.OptionConverters.RichOptionForJava8
 import scala.jdk.CollectionConverters._
 
 object KafkaServer {
@@ -602,7 +603,8 @@ class KafkaServer(
         throw new KafkaException("Tiered storage is not supported with 
multiple log dirs.");
       }
 
-      Some(new RemoteLogManager(remoteLogManagerConfig, config.brokerId, 
config.logDirs.head))
+      Some(new RemoteLogManager(remoteLogManagerConfig, config.brokerId, 
config.logDirs.head, time,
+        (tp: TopicPartition) => logManager.getLog(tp).asJava));
     } else {
       None
     }
diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala 
b/core/src/main/scala/kafka/server/ReplicaManager.scala
index 7a138b1e9f2..0039611c1cb 100644
--- a/core/src/main/scala/kafka/server/ReplicaManager.scala
+++ b/core/src/main/scala/kafka/server/ReplicaManager.scala
@@ -1484,6 +1484,9 @@ class ReplicaManager(val config: KafkaConfig,
 
           replicaFetcherManager.shutdownIdleFetcherThreads()
           replicaAlterLogDirsManager.shutdownIdleFetcherThreads()
+
+          remoteLogManager.foreach(rlm => 
rlm.onLeadershipChange(partitionsBecomeLeader.asJava, 
partitionsBecomeFollower.asJava, topicIds))
+
           onLeadershipChange(partitionsBecomeLeader, partitionsBecomeFollower)
 
           val data = new 
LeaderAndIsrResponseData().setErrorCode(Errors.NONE.code)
diff --git a/core/src/test/java/kafka/log/remote/RemoteLogManagerTest.java 
b/core/src/test/java/kafka/log/remote/RemoteLogManagerTest.java
new file mode 100644
index 00000000000..e56fbd45a08
--- /dev/null
+++ b/core/src/test/java/kafka/log/remote/RemoteLogManagerTest.java
@@ -0,0 +1,573 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package kafka.log.remote;
+
+import kafka.cluster.Partition;
+import kafka.log.LogSegment;
+import kafka.log.UnifiedLog;
+import org.apache.kafka.common.KafkaException;
+import org.apache.kafka.common.TopicIdPartition;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.Uuid;
+import org.apache.kafka.common.config.AbstractConfig;
+import org.apache.kafka.common.record.CompressionType;
+import org.apache.kafka.common.record.FileRecords;
+import org.apache.kafka.common.record.MemoryRecords;
+import org.apache.kafka.common.record.SimpleRecord;
+import org.apache.kafka.common.utils.MockTime;
+import org.apache.kafka.common.utils.Time;
+import 
org.apache.kafka.server.log.remote.storage.ClassLoaderAwareRemoteStorageManager;
+import org.apache.kafka.server.log.remote.storage.LogSegmentData;
+import org.apache.kafka.server.log.remote.storage.NoOpRemoteLogMetadataManager;
+import org.apache.kafka.server.log.remote.storage.NoOpRemoteStorageManager;
+import org.apache.kafka.server.log.remote.storage.RemoteLogManagerConfig;
+import org.apache.kafka.server.log.remote.storage.RemoteLogMetadataManager;
+import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentId;
+import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadata;
+import 
org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadataUpdate;
+import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentState;
+import org.apache.kafka.server.log.remote.storage.RemoteStorageException;
+import org.apache.kafka.server.log.remote.storage.RemoteStorageManager;
+import 
org.apache.kafka.server.log.remote.storage.RemoteStorageManager.IndexType;
+import 
org.apache.kafka.storage.internals.checkpoint.InMemoryLeaderEpochCheckpoint;
+import org.apache.kafka.storage.internals.checkpoint.LeaderEpochCheckpoint;
+import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache;
+import org.apache.kafka.storage.internals.log.EpochEntry;
+import org.apache.kafka.storage.internals.log.LazyIndex;
+import org.apache.kafka.storage.internals.log.OffsetIndex;
+import org.apache.kafka.storage.internals.log.ProducerStateManager;
+import org.apache.kafka.storage.internals.log.TimeIndex;
+import org.apache.kafka.storage.internals.log.TransactionIndex;
+import org.apache.kafka.test.TestUtils;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
+import org.mockito.InOrder;
+import org.mockito.Mockito;
+import scala.Option;
+import scala.collection.JavaConverters;
+
+import java.io.ByteArrayInputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.NavigableMap;
+import java.util.Optional;
+import java.util.Properties;
+import java.util.TreeMap;
+import java.util.concurrent.CompletableFuture;
+
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.inOrder;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+public class RemoteLogManagerTest {
+    Time time = new MockTime();
+    int brokerId = 0;
+    String logDir = TestUtils.tempDirectory("kafka-").toString();
+
+    RemoteStorageManager remoteStorageManager = 
mock(RemoteStorageManager.class);
+    RemoteLogMetadataManager remoteLogMetadataManager = 
mock(RemoteLogMetadataManager.class);
+    RemoteLogManagerConfig remoteLogManagerConfig = null;
+    RemoteLogManager remoteLogManager = null;
+
+    TopicIdPartition leaderTopicIdPartition = new 
TopicIdPartition(Uuid.randomUuid(), new TopicPartition("Leader", 0));
+    TopicIdPartition followerTopicIdPartition = new 
TopicIdPartition(Uuid.randomUuid(), new TopicPartition("Follower", 0));
+    Map<String, Uuid> topicIds = new HashMap<>();
+    TopicPartition tp = new TopicPartition("TestTopic", 5);
+    EpochEntry epochEntry0 = new EpochEntry(0, 0);
+    EpochEntry epochEntry1 = new EpochEntry(1, 100);
+    EpochEntry epochEntry2 = new EpochEntry(2, 200);
+    List<EpochEntry> totalEpochEntries = Arrays.asList(epochEntry0, 
epochEntry1, epochEntry2);
+    LeaderEpochCheckpoint checkpoint = new LeaderEpochCheckpoint() {
+        List<EpochEntry> epochs = Collections.emptyList();
+        @Override
+        public void write(Collection<EpochEntry> epochs) {
+            this.epochs = new ArrayList<>(epochs);
+        }
+
+        @Override
+        public List<EpochEntry> read() {
+            return epochs;
+        }
+    };
+
+    UnifiedLog mockLog = mock(UnifiedLog.class);
+
+    @BeforeEach
+    void setUp() throws Exception {
+        topicIds.put(leaderTopicIdPartition.topicPartition().topic(), 
leaderTopicIdPartition.topicId());
+        topicIds.put(followerTopicIdPartition.topicPartition().topic(), 
followerTopicIdPartition.topicId());
+        Properties props = new Properties();
+        remoteLogManagerConfig = createRLMConfig(props);
+        remoteLogManager = new RemoteLogManager(remoteLogManagerConfig, 
brokerId, logDir, time, tp -> Optional.of(mockLog)) {
+            public RemoteStorageManager createRemoteStorageManager() {
+                return remoteStorageManager;
+            }
+            public RemoteLogMetadataManager createRemoteLogMetadataManager() {
+                return remoteLogMetadataManager;
+            }
+        };
+    }
+
+    @Test
+    void testGetLeaderEpochCheckpoint() {
+        checkpoint.write(totalEpochEntries);
+        LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint);
+        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        InMemoryLeaderEpochCheckpoint inMemoryCheckpoint = 
remoteLogManager.getLeaderEpochCheckpoint(mockLog, 0, 300);
+        assertEquals(totalEpochEntries, inMemoryCheckpoint.read());
+
+        InMemoryLeaderEpochCheckpoint inMemoryCheckpoint2 = 
remoteLogManager.getLeaderEpochCheckpoint(mockLog, 100, 200);
+        List<EpochEntry> epochEntries = inMemoryCheckpoint2.read();
+        assertEquals(1, epochEntries.size());
+        assertEquals(epochEntry1, epochEntries.get(0));
+    }
+
+    @Test
+    void testFindHighestRemoteOffset() throws RemoteStorageException {
+        checkpoint.write(totalEpochEntries);
+        LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint);
+        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        TopicIdPartition tpId = new TopicIdPartition(Uuid.randomUuid(), tp);
+        long offset = remoteLogManager.findHighestRemoteOffset(tpId);
+        assertEquals(-1, offset);
+
+        when(remoteLogMetadataManager.highestOffsetForEpoch(tpId, 
2)).thenReturn(Optional.of(200L));
+        long offset2 = remoteLogManager.findHighestRemoteOffset(tpId);
+        assertEquals(200, offset2);
+    }
+
+    @Test
+    void testRemoteLogMetadataManagerWithUserDefinedConfigs() {
+        String key = "key";
+        String configPrefix = "config.prefix";
+        Properties props = new Properties();
+        
props.put(RemoteLogManagerConfig.REMOTE_LOG_METADATA_MANAGER_CONFIG_PREFIX_PROP,
 configPrefix);
+        props.put(configPrefix + key, "world");
+        props.put("remote.log.metadata.y", "z");
+
+        Map<String, Object> metadataMangerConfig = 
createRLMConfig(props).remoteLogMetadataManagerProps();
+        assertEquals(props.get(configPrefix + key), 
metadataMangerConfig.get(key));
+        assertFalse(metadataMangerConfig.containsKey("remote.log.metadata.y"));
+    }
+
+    @Test
+    void testStartup() {
+        remoteLogManager.startup();
+        ArgumentCaptor<Map<String, Object>> capture = 
ArgumentCaptor.forClass(Map.class);
+        verify(remoteStorageManager, times(1)).configure(capture.capture());
+        assertEquals(brokerId, capture.getValue().get("broker.id"));
+
+        verify(remoteLogMetadataManager, 
times(1)).configure(capture.capture());
+        assertEquals(brokerId, capture.getValue().get("broker.id"));
+        assertEquals(logDir, capture.getValue().get("log.dir"));
+    }
+
+    // This test creates 2 log segments, 1st one has start offset of 0, 2nd 
one (and active one) has start offset of 150.
+    // The leader epochs are [0->0, 1->100, 2->200]. We are verifying:
+    // 1. There's only 1 segment copied to remote storage
+    // 2. The segment got copied to remote storage is the old segment, not the 
active one
+    // 3. The log segment metadata stored into remoteLogMetadataManager is 
what we expected, both before and after copying the log segments
+    // 4. The log segment got copied to remote storage has the expected 
metadata
+    // 5. The highest remote offset is updated to the expected value
+    @Test
+    void testCopyLogSegmentsToRemoteShouldCopyExpectedLogSegment() throws 
Exception {
+        long oldSegmentStartOffset = 0L;
+        long nextSegmentStartOffset = 150L;
+        long oldSegmentEndOffset = nextSegmentStartOffset - 1;
+
+        // leader epoch preparation
+        checkpoint.write(totalEpochEntries);
+        LeaderEpochFileCache cache = new 
LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint);
+        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        
when(remoteLogMetadataManager.highestOffsetForEpoch(any(TopicIdPartition.class),
 anyInt())).thenReturn(Optional.of(0L));
+
+        File tempFile = TestUtils.tempFile();
+        File mockProducerSnapshotIndex = TestUtils.tempFile();
+        File tempDir = TestUtils.tempDirectory();
+        // create 2 log segments, with 0 and 150 as log start offset
+        LogSegment oldSegment = mock(LogSegment.class);
+        LogSegment activeSegment = mock(LogSegment.class);
+
+        when(oldSegment.baseOffset()).thenReturn(oldSegmentStartOffset);
+        when(activeSegment.baseOffset()).thenReturn(nextSegmentStartOffset);
+
+        FileRecords fileRecords = mock(FileRecords.class);
+        when(oldSegment.log()).thenReturn(fileRecords);
+        when(fileRecords.file()).thenReturn(tempFile);
+        when(oldSegment.readNextOffset()).thenReturn(nextSegmentStartOffset);
+
+        when(mockLog.activeSegment()).thenReturn(activeSegment);
+        when(mockLog.logStartOffset()).thenReturn(oldSegmentStartOffset);
+        when(mockLog.logSegments(anyLong(), 
anyLong())).thenReturn(JavaConverters.collectionAsScalaIterable(Arrays.asList(oldSegment,
 activeSegment)));
+
+        ProducerStateManager mockStateManager = 
mock(ProducerStateManager.class);
+        when(mockLog.producerStateManager()).thenReturn(mockStateManager);
+        
when(mockStateManager.fetchSnapshot(anyLong())).thenReturn(Optional.of(mockProducerSnapshotIndex));
+        when(mockLog.lastStableOffset()).thenReturn(250L);
+
+        LazyIndex idx = 
LazyIndex.forOffset(UnifiedLog.offsetIndexFile(tempDir, oldSegmentStartOffset, 
""), oldSegmentStartOffset, 1000);
+        LazyIndex timeIdx = 
LazyIndex.forTime(UnifiedLog.timeIndexFile(tempDir, oldSegmentStartOffset, ""), 
oldSegmentStartOffset, 1500);
+        File txnFile = UnifiedLog.transactionIndexFile(tempDir, 
oldSegmentStartOffset, "");
+        txnFile.createNewFile();
+        TransactionIndex txnIndex = new 
TransactionIndex(oldSegmentStartOffset, txnFile);
+        when(oldSegment.lazyTimeIndex()).thenReturn(timeIdx);
+        when(oldSegment.lazyOffsetIndex()).thenReturn(idx);
+        when(oldSegment.txnIndex()).thenReturn(txnIndex);
+
+        CompletableFuture<Void> dummyFuture = new CompletableFuture<>();
+        dummyFuture.complete(null);
+        
when(remoteLogMetadataManager.addRemoteLogSegmentMetadata(any(RemoteLogSegmentMetadata.class))).thenReturn(dummyFuture);
+        
when(remoteLogMetadataManager.updateRemoteLogSegmentMetadata(any(RemoteLogSegmentMetadataUpdate.class))).thenReturn(dummyFuture);
+        
doNothing().when(remoteStorageManager).copyLogSegmentData(any(RemoteLogSegmentMetadata.class),
 any(LogSegmentData.class));
+
+        RemoteLogManager.RLMTask task = remoteLogManager.new 
RLMTask(leaderTopicIdPartition);
+        task.convertToLeader(2);
+        task.copyLogSegmentsToRemote();
+
+        // verify remoteLogMetadataManager did add the expected 
RemoteLogSegmentMetadata
+        ArgumentCaptor<RemoteLogSegmentMetadata> remoteLogSegmentMetadataArg = 
ArgumentCaptor.forClass(RemoteLogSegmentMetadata.class);
+        
verify(remoteLogMetadataManager).addRemoteLogSegmentMetadata(remoteLogSegmentMetadataArg.capture());
+        // The old segment should only contain leader epoch [0->0, 1->100] 
since its offset range is [0, 149]
+        Map<Integer, Long> expectedLeaderEpochs = new TreeMap<>();
+        expectedLeaderEpochs.put(epochEntry0.epoch, epochEntry0.startOffset);
+        expectedLeaderEpochs.put(epochEntry1.epoch, epochEntry1.startOffset);
+        verifyRemoteLogSegmentMetadata(remoteLogSegmentMetadataArg.getValue(), 
oldSegmentStartOffset, oldSegmentEndOffset, expectedLeaderEpochs);
+
+        // verify copyLogSegmentData is passing the RemoteLogSegmentMetadata 
we created above
+        // and verify the logSegmentData passed is expected
+        ArgumentCaptor<RemoteLogSegmentMetadata> remoteLogSegmentMetadataArg2 
= ArgumentCaptor.forClass(RemoteLogSegmentMetadata.class);
+        ArgumentCaptor<LogSegmentData> logSegmentDataArg = 
ArgumentCaptor.forClass(LogSegmentData.class);
+        verify(remoteStorageManager, 
times(1)).copyLogSegmentData(remoteLogSegmentMetadataArg2.capture(), 
logSegmentDataArg.capture());
+        assertEquals(remoteLogSegmentMetadataArg.getValue(), 
remoteLogSegmentMetadataArg2.getValue());
+        // The old segment should only contain leader epoch [0->0, 1->100] 
since its offset range is [0, 149]
+        verifyLogSegmentData(logSegmentDataArg.getValue(), idx, timeIdx, 
txnIndex, tempFile, mockProducerSnapshotIndex,
+            Arrays.asList(epochEntry0, epochEntry1));
+
+        // verify remoteLogMetadataManager did add the expected 
RemoteLogSegmentMetadataUpdate
+        ArgumentCaptor<RemoteLogSegmentMetadataUpdate> 
remoteLogSegmentMetadataUpdateArg = 
ArgumentCaptor.forClass(RemoteLogSegmentMetadataUpdate.class);
+        verify(remoteLogMetadataManager, 
times(1)).updateRemoteLogSegmentMetadata(remoteLogSegmentMetadataUpdateArg.capture());
+        
verifyRemoteLogSegmentMetadataUpdate(remoteLogSegmentMetadataUpdateArg.getValue());
+
+        // verify the highest remote offset is updated to the expected value
+        ArgumentCaptor<Long> argument = ArgumentCaptor.forClass(Long.class);
+        verify(mockLog, 
times(1)).updateHighestOffsetInRemoteStorage(argument.capture());
+        assertEquals(oldSegmentEndOffset, argument.getValue());
+    }
+
+    @Test
+    void testCopyLogSegmentsToRemoteShouldNotCopySegmentForFollower() throws 
Exception {
+        long oldSegmentStartOffset = 0L;
+        long nextSegmentStartOffset = 150L;
+
+        // leader epoch preparation
+        checkpoint.write(totalEpochEntries);
+        LeaderEpochFileCache cache = new 
LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint);
+        when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
+        
when(remoteLogMetadataManager.highestOffsetForEpoch(any(TopicIdPartition.class),
 anyInt())).thenReturn(Optional.of(0L));
+
+        // create 2 log segments, with 0 and 150 as log start offset
+        LogSegment oldSegment = mock(LogSegment.class);
+        LogSegment activeSegment = mock(LogSegment.class);
+
+        when(oldSegment.baseOffset()).thenReturn(oldSegmentStartOffset);
+        when(activeSegment.baseOffset()).thenReturn(nextSegmentStartOffset);
+
+        when(mockLog.activeSegment()).thenReturn(activeSegment);
+        when(mockLog.logStartOffset()).thenReturn(oldSegmentStartOffset);
+        when(mockLog.logSegments(anyLong(), 
anyLong())).thenReturn(JavaConverters.collectionAsScalaIterable(Arrays.asList(oldSegment,
 activeSegment)));
+        when(mockLog.lastStableOffset()).thenReturn(250L);
+
+        RemoteLogManager.RLMTask task = remoteLogManager.new 
RLMTask(leaderTopicIdPartition);
+        task.convertToFollower();
+        task.copyLogSegmentsToRemote();
+
+        // verify the remoteLogMetadataManager never add any metadata and 
remoteStorageManager never copy log segments
+        verify(remoteLogMetadataManager, 
never()).addRemoteLogSegmentMetadata(any(RemoteLogSegmentMetadata.class));
+        verify(remoteStorageManager, 
never()).copyLogSegmentData(any(RemoteLogSegmentMetadata.class), 
any(LogSegmentData.class));
+        verify(remoteLogMetadataManager, 
never()).updateRemoteLogSegmentMetadata(any(RemoteLogSegmentMetadataUpdate.class));
+        verify(mockLog, never()).updateHighestOffsetInRemoteStorage(anyLong());
+    }
+
+    private void verifyRemoteLogSegmentMetadata(RemoteLogSegmentMetadata 
remoteLogSegmentMetadata,
+                                                long oldSegmentStartOffset,
+                                                long oldSegmentEndOffset,
+                                                Map<Integer, Long> 
expectedLeaderEpochs) {
+        assertEquals(leaderTopicIdPartition, 
remoteLogSegmentMetadata.remoteLogSegmentId().topicIdPartition());
+        assertEquals(oldSegmentStartOffset, 
remoteLogSegmentMetadata.startOffset());
+        assertEquals(oldSegmentEndOffset, 
remoteLogSegmentMetadata.endOffset());
+
+        NavigableMap<Integer, Long> leaderEpochs = 
remoteLogSegmentMetadata.segmentLeaderEpochs();
+        assertEquals(expectedLeaderEpochs.size(), leaderEpochs.size());
+        Iterator<Map.Entry<Integer, Long>> leaderEpochEntries = 
expectedLeaderEpochs.entrySet().iterator();
+        assertEquals(leaderEpochEntries.next(), leaderEpochs.firstEntry());
+        assertEquals(leaderEpochEntries.next(), leaderEpochs.lastEntry());
+
+        assertEquals(brokerId, remoteLogSegmentMetadata.brokerId());
+        assertEquals(RemoteLogSegmentState.COPY_SEGMENT_STARTED, 
remoteLogSegmentMetadata.state());
+    }
+
+    private void 
verifyRemoteLogSegmentMetadataUpdate(RemoteLogSegmentMetadataUpdate 
remoteLogSegmentMetadataUpdate) {
+        assertEquals(leaderTopicIdPartition, 
remoteLogSegmentMetadataUpdate.remoteLogSegmentId().topicIdPartition());
+        assertEquals(brokerId, remoteLogSegmentMetadataUpdate.brokerId());
+
+        assertEquals(RemoteLogSegmentState.COPY_SEGMENT_FINISHED, 
remoteLogSegmentMetadataUpdate.state());
+    }
+
+    private void verifyLogSegmentData(LogSegmentData logSegmentData,
+                                      LazyIndex idx,
+                                      LazyIndex timeIdx,
+                                      TransactionIndex txnIndex,
+                                      File tempFile,
+                                      File mockProducerSnapshotIndex,
+                                      List<EpochEntry> expectedLeaderEpoch) 
throws IOException {
+        assertEquals(idx.file().getAbsolutePath(), 
logSegmentData.offsetIndex().toAbsolutePath().toString());
+        assertEquals(timeIdx.file().getAbsolutePath(), 
logSegmentData.timeIndex().toAbsolutePath().toString());
+        assertEquals(txnIndex.file().getPath(), 
logSegmentData.transactionIndex().get().toAbsolutePath().toString());
+        assertEquals(tempFile.getAbsolutePath(), 
logSegmentData.logSegment().toAbsolutePath().toString());
+        assertEquals(mockProducerSnapshotIndex.getAbsolutePath(), 
logSegmentData.producerSnapshotIndex().toAbsolutePath().toString());
+
+        InMemoryLeaderEpochCheckpoint inMemoryLeaderEpochCheckpoint = new 
InMemoryLeaderEpochCheckpoint();
+        inMemoryLeaderEpochCheckpoint.write(expectedLeaderEpoch);
+        assertEquals(inMemoryLeaderEpochCheckpoint.readAsByteBuffer(), 
logSegmentData.leaderEpochIndex());
+    }
+
+    @Test
+    void testGetClassLoaderAwareRemoteStorageManager() throws Exception {
+        ClassLoaderAwareRemoteStorageManager rsmManager = 
mock(ClassLoaderAwareRemoteStorageManager.class);
+        RemoteLogManager remoteLogManager =
+            new RemoteLogManager(remoteLogManagerConfig, brokerId, logDir, 
time, t -> Optional.empty()) {
+                public RemoteStorageManager createRemoteStorageManager() {
+                    return rsmManager;
+                }
+            };
+        assertEquals(rsmManager, remoteLogManager.storageManager());
+    }
+
+    private void verifyInCache(TopicIdPartition... topicIdPartitions) {
+        Arrays.stream(topicIdPartitions).forEach(topicIdPartition -> {
+            assertDoesNotThrow(() -> 
remoteLogManager.fetchRemoteLogSegmentMetadata(topicIdPartition.topicPartition(),
 0, 0L));
+        });
+    }
+
+    private void verifyNotInCache(TopicIdPartition... topicIdPartitions) {
+        Arrays.stream(topicIdPartitions).forEach(topicIdPartition -> {
+            assertThrows(KafkaException.class, () ->
+                
remoteLogManager.fetchRemoteLogSegmentMetadata(topicIdPartition.topicPartition(),
 0, 0L));
+        });
+    }
+
+    @Test
+    void testTopicIdCacheUpdates() throws RemoteStorageException {
+        Partition mockLeaderPartition = mockPartition(leaderTopicIdPartition);
+        Partition mockFollowerPartition = 
mockPartition(followerTopicIdPartition);
+
+        
when(remoteLogMetadataManager.remoteLogSegmentMetadata(any(TopicIdPartition.class),
 anyInt(), anyLong()))
+            .thenReturn(Optional.empty());
+        verifyNotInCache(followerTopicIdPartition, leaderTopicIdPartition);
+        // Load topicId cache
+        
remoteLogManager.onLeadershipChange(Collections.singleton(mockLeaderPartition), 
Collections.singleton(mockFollowerPartition), topicIds);
+        verify(remoteLogMetadataManager, times(1))
+            
.onPartitionLeadershipChanges(Collections.singleton(leaderTopicIdPartition), 
Collections.singleton(followerTopicIdPartition));
+        verifyInCache(followerTopicIdPartition, leaderTopicIdPartition);
+
+        // Evicts from topicId cache
+        
remoteLogManager.stopPartitions(leaderTopicIdPartition.topicPartition(), true);
+        verifyNotInCache(leaderTopicIdPartition);
+        verifyInCache(followerTopicIdPartition);
+
+        // Evicts from topicId cache
+        
remoteLogManager.stopPartitions(followerTopicIdPartition.topicPartition(), 
true);
+        verifyNotInCache(leaderTopicIdPartition, followerTopicIdPartition);
+    }
+
+    @Test
+    void testFetchRemoteLogSegmentMetadata() throws RemoteStorageException {
+        remoteLogManager.onLeadershipChange(
+            Collections.singleton(mockPartition(leaderTopicIdPartition)), 
Collections.singleton(mockPartition(followerTopicIdPartition)), topicIds);
+        
remoteLogManager.fetchRemoteLogSegmentMetadata(leaderTopicIdPartition.topicPartition(),
 10, 100L);
+        
remoteLogManager.fetchRemoteLogSegmentMetadata(followerTopicIdPartition.topicPartition(),
 20, 200L);
+
+        verify(remoteLogMetadataManager)
+            .remoteLogSegmentMetadata(eq(leaderTopicIdPartition), anyInt(), 
anyLong());
+        verify(remoteLogMetadataManager)
+            .remoteLogSegmentMetadata(eq(followerTopicIdPartition), anyInt(), 
anyLong());
+    }
+
+    @Test
+    void testOnLeadershipChangeWillInvokeHandleLeaderOrFollowerPartitions() {
+        RemoteLogManager spyRemoteLogManager = spy(remoteLogManager);
+        spyRemoteLogManager.onLeadershipChange(
+            Collections.emptySet(), 
Collections.singleton(mockPartition(followerTopicIdPartition)), topicIds);
+        
verify(spyRemoteLogManager).doHandleLeaderOrFollowerPartitions(eq(followerTopicIdPartition),
 any(java.util.function.Consumer.class));
+
+        Mockito.reset(spyRemoteLogManager);
+
+        spyRemoteLogManager.onLeadershipChange(
+            Collections.singleton(mockPartition(leaderTopicIdPartition)), 
Collections.emptySet(), topicIds);
+        
verify(spyRemoteLogManager).doHandleLeaderOrFollowerPartitions(eq(leaderTopicIdPartition),
 any(java.util.function.Consumer.class));
+    }
+
+    private MemoryRecords records(long timestamp,
+                                  long initialOffset,
+                                  int partitionLeaderEpoch) {
+        return MemoryRecords.withRecords(initialOffset, CompressionType.NONE, 
partitionLeaderEpoch,
+            new SimpleRecord(timestamp - 1, "first message".getBytes()),
+            new SimpleRecord(timestamp + 1, "second message".getBytes()),
+            new SimpleRecord(timestamp + 2, "third message".getBytes())
+            );
+    }
+
+    @Test
+    void testRLMTaskShouldSetLeaderEpochCorrectly() {
+        RemoteLogManager.RLMTask task = remoteLogManager.new 
RLMTask(leaderTopicIdPartition);
+        assertFalse(task.isLeader());
+        task.convertToLeader(1);
+        assertTrue(task.isLeader());
+        task.convertToFollower();
+        assertFalse(task.isLeader());
+    }
+
+    @Test
+    void testFindOffsetByTimestamp() throws IOException, 
RemoteStorageException {
+        TopicPartition tp = leaderTopicIdPartition.topicPartition();
+        RemoteLogSegmentId remoteLogSegmentId = new 
RemoteLogSegmentId(leaderTopicIdPartition, Uuid.randomUuid());
+        long ts = time.milliseconds();
+        long startOffset = 120;
+        int targetLeaderEpoch = 10;
+
+        RemoteLogSegmentMetadata segmentMetadata = 
mock(RemoteLogSegmentMetadata.class);
+        
when(segmentMetadata.remoteLogSegmentId()).thenReturn(remoteLogSegmentId);
+        when(segmentMetadata.maxTimestampMs()).thenReturn(ts + 2);
+        when(segmentMetadata.startOffset()).thenReturn(startOffset);
+        when(segmentMetadata.endOffset()).thenReturn(startOffset + 2);
+
+        File tpDir = new File(logDir, tp.toString());
+        Files.createDirectory(tpDir.toPath());
+        File txnIdxFile = new File(tpDir, "txn-index" + 
UnifiedLog.TxnIndexFileSuffix());
+        txnIdxFile.createNewFile();
+        
when(remoteStorageManager.fetchIndex(any(RemoteLogSegmentMetadata.class), 
any(IndexType.class)))
+            .thenAnswer(ans -> {
+                RemoteLogSegmentMetadata metadata = 
ans.<RemoteLogSegmentMetadata>getArgument(0);
+                IndexType indexType = ans.<IndexType>getArgument(1);
+                int maxEntries = (int) (metadata.endOffset() - 
metadata.startOffset());
+                OffsetIndex offsetIdx = new OffsetIndex(new File(tpDir, 
String.valueOf(metadata.startOffset()) + UnifiedLog.IndexFileSuffix()),
+                    metadata.startOffset(), maxEntries * 8);
+                TimeIndex timeIdx = new TimeIndex(new File(tpDir, 
String.valueOf(metadata.startOffset()) + UnifiedLog.TimeIndexFileSuffix()),
+                    metadata.startOffset(), maxEntries * 12);
+                switch (indexType) {
+                    case OFFSET:
+                        return new FileInputStream(offsetIdx.file());
+                    case TIMESTAMP:
+                        return new FileInputStream(timeIdx.file());
+                    case TRANSACTION:
+                        return new FileInputStream(txnIdxFile);
+                }
+                return null;
+            });
+
+        
when(remoteLogMetadataManager.listRemoteLogSegments(eq(leaderTopicIdPartition), 
anyInt()))
+            .thenAnswer(ans -> {
+                int leaderEpoch = ans.<Integer>getArgument(1);
+                if (leaderEpoch == targetLeaderEpoch)
+                    return Collections.singleton(segmentMetadata).iterator();
+                else
+                    return Collections.emptyList().iterator();
+            });
+
+
+
+        // 3 messages are added with offset, and timestamp as below
+        // startOffset   , ts-1
+        // startOffset+1 , ts+1
+        // startOffset+2 , ts+2
+        when(remoteStorageManager.fetchLogSegment(segmentMetadata, 0))
+            .thenAnswer(a -> new ByteArrayInputStream(records(ts, startOffset, 
targetLeaderEpoch).buffer().array()));
+
+        LeaderEpochFileCache leaderEpochFileCache = new 
LeaderEpochFileCache(tp, checkpoint);
+        leaderEpochFileCache.assign(5, 99L);
+        leaderEpochFileCache.assign(targetLeaderEpoch, startOffset);
+        leaderEpochFileCache.assign(12, 500L);
+
+        
remoteLogManager.onLeadershipChange(Collections.singleton(mockPartition(leaderTopicIdPartition)),
 Collections.emptySet(), topicIds);
+        // Fetching message for timestamp `ts` will return the message with 
startOffset+1, and `ts+1` as there are no
+        // messages starting with the startOffset and with `ts`.
+        Optional<FileRecords.TimestampAndOffset> maybeTimestampAndOffset1 = 
remoteLogManager.findOffsetByTimestamp(tp, ts, startOffset, 
leaderEpochFileCache);
+        assertEquals(Optional.of(new FileRecords.TimestampAndOffset(ts + 1, 
startOffset + 1, Optional.of(targetLeaderEpoch))), maybeTimestampAndOffset1);
+
+        // Fetching message for `ts+2` will return the message with 
startOffset+2 and its timestamp value is `ts+2`.
+        Optional<FileRecords.TimestampAndOffset> maybeTimestampAndOffset2 = 
remoteLogManager.findOffsetByTimestamp(tp, ts + 2, startOffset, 
leaderEpochFileCache);
+        assertEquals(Optional.of(new FileRecords.TimestampAndOffset(ts + 2, 
startOffset + 2, Optional.of(targetLeaderEpoch))), maybeTimestampAndOffset2);
+
+        // Fetching message for `ts+3` will return None as there are no 
records with timestamp >= ts+3.
+        Optional<FileRecords.TimestampAndOffset>  maybeTimestampAndOffset3 = 
remoteLogManager.findOffsetByTimestamp(tp, ts + 3, startOffset, 
leaderEpochFileCache);
+        assertEquals(Optional.empty(), maybeTimestampAndOffset3);
+    }
+
+    @Test
+    void testIdempotentClose() throws IOException {
+        remoteLogManager.close();
+        remoteLogManager.close();
+        InOrder inorder = inOrder(remoteStorageManager, 
remoteLogMetadataManager);
+        inorder.verify(remoteStorageManager, times(1)).close();
+        inorder.verify(remoteLogMetadataManager, times(1)).close();
+    }
+
+    private Partition mockPartition(TopicIdPartition topicIdPartition) {
+        TopicPartition tp = topicIdPartition.topicPartition();
+        Partition partition = mock(Partition.class);
+        UnifiedLog log = mock(UnifiedLog.class);
+        when(partition.topicPartition()).thenReturn(tp);
+        when(partition.topic()).thenReturn(tp.topic());
+        when(log.remoteLogEnabled()).thenReturn(true);
+        when(partition.log()).thenReturn(Option.apply(log));
+        return partition;
+    }
+
+    private RemoteLogManagerConfig createRLMConfig(Properties props) {
+        
props.put(RemoteLogManagerConfig.REMOTE_LOG_STORAGE_SYSTEM_ENABLE_PROP, true);
+        
props.put(RemoteLogManagerConfig.REMOTE_STORAGE_MANAGER_CLASS_NAME_PROP, 
NoOpRemoteStorageManager.class.getName());
+        
props.put(RemoteLogManagerConfig.REMOTE_LOG_METADATA_MANAGER_CLASS_NAME_PROP, 
NoOpRemoteLogMetadataManager.class.getName());
+        AbstractConfig config = new 
AbstractConfig(RemoteLogManagerConfig.CONFIG_DEF, props);
+        return new RemoteLogManagerConfig(config);
+    }
+
+}
diff --git a/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala 
b/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala
index e54dab2e493..7de7288a1c1 100755
--- a/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala
+++ b/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala
@@ -2035,7 +2035,7 @@ class UnifiedLogTest {
       remoteLogStorageEnable = true)
     val log = createLog(logDir, logConfig, remoteStorageSystemEnable = true, 
remoteLogManager = Some(remoteLogManager))
     when(remoteLogManager.findOffsetByTimestamp(log.topicPartition, 0, 0, 
log.leaderEpochCache.get))
-      .thenReturn(None)
+      .thenReturn(Optional.empty[TimestampAndOffset]())
     assertEquals(None, log.fetchOffsetByTimestamp(0L))
 
     val firstTimestamp = mockTime.milliseconds
@@ -2056,9 +2056,9 @@ class UnifiedLogTest {
       anyLong(), anyLong(), ArgumentMatchers.eq(log.leaderEpochCache.get)))
       .thenAnswer(ans => {
         val timestamp = ans.getArgument(1).asInstanceOf[Long]
-        Option(timestamp)
+        Optional.of(timestamp)
           .filter(_ == firstTimestamp)
-          .map(new TimestampAndOffset(_, 0L, Optional.of(firstLeaderEpoch)))
+          .map[TimestampAndOffset](x => new TimestampAndOffset(x, 0L, 
Optional.of(firstLeaderEpoch)))
       })
     log._localLogStartOffset = 1
 
diff --git 
a/core/src/test/scala/unit/kafka/log/remote/RemoteLogManagerTest.scala 
b/core/src/test/scala/unit/kafka/log/remote/RemoteLogManagerTest.scala
deleted file mode 100644
index 237ff4682f7..00000000000
--- a/core/src/test/scala/unit/kafka/log/remote/RemoteLogManagerTest.scala
+++ /dev/null
@@ -1,277 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package kafka.log.remote
-
-import kafka.cluster.Partition
-import kafka.log.UnifiedLog
-import kafka.server.KafkaConfig
-import kafka.utils.MockTime
-import org.apache.kafka.common.config.AbstractConfig
-import org.apache.kafka.common.record.FileRecords.TimestampAndOffset
-import org.apache.kafka.common.record.{CompressionType, MemoryRecords, 
SimpleRecord}
-import org.apache.kafka.common.{KafkaException, TopicIdPartition, 
TopicPartition, Uuid}
-import 
org.apache.kafka.server.log.remote.storage.RemoteStorageManager.IndexType
-import org.apache.kafka.server.log.remote.storage._
-import org.apache.kafka.storage.internals.checkpoint.LeaderEpochCheckpoint
-import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache
-import org.apache.kafka.storage.internals.log.{EpochEntry, OffsetIndex, 
TimeIndex}
-import org.apache.kafka.test.TestUtils
-import org.junit.jupiter.api.Assertions._
-import org.junit.jupiter.api.{BeforeEach, Test}
-import org.mockito.ArgumentMatchers.{any, anyInt, anyLong}
-import org.mockito.Mockito._
-import org.mockito.{ArgumentCaptor, ArgumentMatchers}
-
-import java.io.{ByteArrayInputStream, File, FileInputStream}
-import java.nio.file.Files
-import java.util
-import java.util.{Optional, Properties}
-import scala.collection.Seq
-import scala.jdk.CollectionConverters._
-
-class RemoteLogManagerTest {
-
-  val time = new MockTime()
-  val brokerId = 0
-  val logDir: String = TestUtils.tempDirectory("kafka-").toString
-
-  val remoteStorageManager: RemoteStorageManager = 
mock(classOf[RemoteStorageManager])
-  val remoteLogMetadataManager: RemoteLogMetadataManager = 
mock(classOf[RemoteLogMetadataManager])
-  var remoteLogManagerConfig: RemoteLogManagerConfig = _
-  var remoteLogManager: RemoteLogManager = _
-
-  val leaderTopicIdPartition =  new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("Leader", 0))
-  val followerTopicIdPartition = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("Follower", 0))
-  val topicIds: util.Map[String, Uuid] = Map(
-    leaderTopicIdPartition.topicPartition().topic() -> 
leaderTopicIdPartition.topicId(),
-    followerTopicIdPartition.topicPartition().topic() -> 
followerTopicIdPartition.topicId()
-  ).asJava
-
-  val checkpoint: LeaderEpochCheckpoint = new LeaderEpochCheckpoint {
-    var epochs: Seq[EpochEntry] = Seq()
-    override def write(epochs: util.Collection[EpochEntry]): Unit = 
this.epochs = epochs.asScala.toSeq
-    override def read(): util.List[EpochEntry] = this.epochs.asJava
-  }
-
-  @BeforeEach
-  def setUp(): Unit = {
-    val props = new Properties()
-    remoteLogManagerConfig = createRLMConfig(props)
-    remoteLogManager = new RemoteLogManager(remoteLogManagerConfig, brokerId, 
logDir) {
-      override private[remote] def createRemoteStorageManager() = 
remoteStorageManager
-      override private[remote] def createRemoteLogMetadataManager() = 
remoteLogMetadataManager
-    }
-  }
-
-  @Test
-  def testRemoteLogMetadataManagerWithUserDefinedConfigs(): Unit = {
-    val key = "key"
-    val configPrefix = "config.prefix"
-    val props: Properties = new Properties()
-    
props.put(RemoteLogManagerConfig.REMOTE_LOG_METADATA_MANAGER_CONFIG_PREFIX_PROP,
 configPrefix)
-    props.put(configPrefix + key, "world")
-    props.put("remote.log.metadata.y", "z")
-
-    val metadataMangerConfig = 
createRLMConfig(props).remoteLogMetadataManagerProps()
-    assertEquals(props.get(configPrefix + key), metadataMangerConfig.get(key))
-    assertFalse(metadataMangerConfig.containsKey("remote.log.metadata.y"))
-  }
-
-  @Test
-  def testStartup(): Unit = {
-    remoteLogManager.startup()
-    val capture: ArgumentCaptor[util.Map[String, _]] = 
ArgumentCaptor.forClass(classOf[util.Map[String, _]])
-    verify(remoteStorageManager, times(1)).configure(capture.capture())
-    assertEquals(brokerId, capture.getValue.get(KafkaConfig.BrokerIdProp))
-
-    verify(remoteLogMetadataManager, times(1)).configure(capture.capture())
-    assertEquals(brokerId, capture.getValue.get(KafkaConfig.BrokerIdProp))
-    assertEquals(logDir, capture.getValue.get(KafkaConfig.LogDirProp))
-  }
-
-  @Test
-  def testGetClassLoaderAwareRemoteStorageManager(): Unit = {
-    val rsmManager: ClassLoaderAwareRemoteStorageManager = 
mock(classOf[ClassLoaderAwareRemoteStorageManager])
-    val remoteLogManager =
-      new RemoteLogManager(remoteLogManagerConfig, brokerId, logDir) {
-        override private[remote] def createRemoteStorageManager(): 
ClassLoaderAwareRemoteStorageManager = rsmManager
-      }
-    assertEquals(rsmManager, remoteLogManager.storageManager())
-  }
-
-  @Test
-  def testTopicIdCacheUpdates(): Unit = {
-    def verifyInCache(topicIdPartitions: TopicIdPartition*): Unit = {
-      topicIdPartitions.foreach { topicIdPartition =>
-        assertDoesNotThrow(() =>
-          
remoteLogManager.fetchRemoteLogSegmentMetadata(topicIdPartition.topicPartition(),
 epochForOffset = 0, offset = 0L))
-      }
-    }
-
-    def verifyNotInCache(topicIdPartitions: TopicIdPartition*): Unit = {
-      topicIdPartitions.foreach { topicIdPartition =>
-        assertThrows(classOf[KafkaException], () =>
-          
remoteLogManager.fetchRemoteLogSegmentMetadata(topicIdPartition.topicPartition(),
 epochForOffset = 0, offset = 0L))
-      }
-    }
-
-    val mockLeaderPartition = mockPartition(leaderTopicIdPartition)
-    val mockFollowerPartition = mockPartition(followerTopicIdPartition)
-
-    
when(remoteLogMetadataManager.remoteLogSegmentMetadata(any(classOf[TopicIdPartition]),
 anyInt(), anyLong()))
-      .thenReturn(Optional.empty[RemoteLogSegmentMetadata]())
-    verifyNotInCache(followerTopicIdPartition, leaderTopicIdPartition)
-    // Load topicId cache
-    remoteLogManager.onLeadershipChange(Set(mockLeaderPartition), 
Set(mockFollowerPartition), topicIds)
-    verify(remoteLogMetadataManager, times(1))
-      .onPartitionLeadershipChanges(Set(leaderTopicIdPartition).asJava, 
Set(followerTopicIdPartition).asJava)
-    verifyInCache(followerTopicIdPartition, leaderTopicIdPartition)
-
-    // Evicts from topicId cache
-    remoteLogManager.stopPartitions(leaderTopicIdPartition.topicPartition(), 
delete = true)
-    verifyNotInCache(leaderTopicIdPartition)
-    verifyInCache(followerTopicIdPartition)
-
-    // Evicts from topicId cache
-    remoteLogManager.stopPartitions(followerTopicIdPartition.topicPartition(), 
delete = true)
-    verifyNotInCache(leaderTopicIdPartition, followerTopicIdPartition)
-  }
-
-  @Test
-  def testFetchRemoteLogSegmentMetadata(): Unit = {
-    remoteLogManager.onLeadershipChange(
-      Set(mockPartition(leaderTopicIdPartition)), 
Set(mockPartition(followerTopicIdPartition)), topicIds)
-    
remoteLogManager.fetchRemoteLogSegmentMetadata(leaderTopicIdPartition.topicPartition(),
 10, 100L)
-    
remoteLogManager.fetchRemoteLogSegmentMetadata(followerTopicIdPartition.topicPartition(),
 20, 200L)
-
-    verify(remoteLogMetadataManager)
-      .remoteLogSegmentMetadata(ArgumentMatchers.eq(leaderTopicIdPartition), 
anyInt(), anyLong())
-    verify(remoteLogMetadataManager)
-      .remoteLogSegmentMetadata(ArgumentMatchers.eq(followerTopicIdPartition), 
anyInt(), anyLong())
-  }
-
-  @Test
-  def testFindOffsetByTimestamp(): Unit = {
-    val tp = leaderTopicIdPartition.topicPartition()
-    val remoteLogSegmentId = new RemoteLogSegmentId(leaderTopicIdPartition, 
Uuid.randomUuid())
-    val ts = time.milliseconds()
-    val startOffset = 120
-    val targetLeaderEpoch = 10
-
-    val segmentMetadata = mock(classOf[RemoteLogSegmentMetadata])
-    when(segmentMetadata.remoteLogSegmentId()).thenReturn(remoteLogSegmentId)
-    when(segmentMetadata.maxTimestampMs()).thenReturn(ts + 2)
-    when(segmentMetadata.startOffset()).thenReturn(startOffset)
-    when(segmentMetadata.endOffset()).thenReturn(startOffset + 2)
-
-    val tpDir: File = new File(logDir, tp.toString)
-    Files.createDirectory(tpDir.toPath)
-    val txnIdxFile = new File(tpDir, "txn-index" + 
UnifiedLog.TxnIndexFileSuffix)
-    txnIdxFile.createNewFile()
-    
when(remoteStorageManager.fetchIndex(any(classOf[RemoteLogSegmentMetadata]), 
any(classOf[IndexType])))
-      .thenAnswer { ans =>
-        val metadata = ans.getArgument[RemoteLogSegmentMetadata](0)
-        val indexType = ans.getArgument[IndexType](1)
-        val maxEntries = (metadata.endOffset() - 
metadata.startOffset()).asInstanceOf[Int]
-        val offsetIdx = new OffsetIndex(new File(tpDir, 
String.valueOf(metadata.startOffset()) + UnifiedLog.IndexFileSuffix),
-          metadata.startOffset(), maxEntries * 8)
-        val timeIdx = new TimeIndex(new File(tpDir, 
String.valueOf(metadata.startOffset()) + UnifiedLog.TimeIndexFileSuffix),
-          metadata.startOffset(), maxEntries * 12)
-        indexType match {
-          case IndexType.OFFSET => new FileInputStream(offsetIdx.file)
-          case IndexType.TIMESTAMP => new FileInputStream(timeIdx.file)
-          case IndexType.TRANSACTION => new FileInputStream(txnIdxFile)
-          case IndexType.LEADER_EPOCH =>
-          case IndexType.PRODUCER_SNAPSHOT =>
-        }
-      }
-
-    
when(remoteLogMetadataManager.listRemoteLogSegments(ArgumentMatchers.eq(leaderTopicIdPartition),
 anyInt()))
-      .thenAnswer(ans => {
-        val leaderEpoch = ans.getArgument[Int](1)
-        if (leaderEpoch == targetLeaderEpoch)
-          List(segmentMetadata).asJava.iterator()
-        else
-          List().asJava.iterator()
-      })
-
-    def records(timestamp: Long,
-                initialOffset: Long,
-                partitionLeaderEpoch: Int): MemoryRecords = {
-      MemoryRecords.withRecords(initialOffset, CompressionType.NONE, 
partitionLeaderEpoch,
-        new SimpleRecord(timestamp - 1, "first message".getBytes()),
-        new SimpleRecord(timestamp + 1, "second message".getBytes()),
-        new SimpleRecord(timestamp + 2, "third message".getBytes()),
-      )
-    }
-
-    // 3 messages are added with offset, and timestamp as below
-    // startOffset   , ts-1
-    // startOffset+1 , ts+1
-    // startOffset+2 , ts+2
-    when(remoteStorageManager.fetchLogSegment(segmentMetadata, 0))
-      .thenAnswer(_ => new ByteArrayInputStream(records(ts, startOffset, 
targetLeaderEpoch).buffer().array()))
-
-    val leaderEpochFileCache = new LeaderEpochFileCache(tp, checkpoint)
-    leaderEpochFileCache.assign(5, 99L)
-    leaderEpochFileCache.assign(targetLeaderEpoch, startOffset)
-    leaderEpochFileCache.assign(12, 500L)
-
-    
remoteLogManager.onLeadershipChange(Set(mockPartition(leaderTopicIdPartition)), 
Set(), topicIds)
-    // Fetching message for timestamp `ts` will return the message with 
startOffset+1, and `ts+1` as there are no
-    // messages starting with the startOffset and with `ts`.
-    val maybeTimestampAndOffset1 = remoteLogManager.findOffsetByTimestamp(tp, 
ts, startOffset, leaderEpochFileCache)
-    assertEquals(Some(new TimestampAndOffset(ts + 1, startOffset + 1, 
Optional.of(targetLeaderEpoch))), maybeTimestampAndOffset1)
-
-    // Fetching message for `ts+2` will return the message with startOffset+2 
and its timestamp value is `ts+2`.
-    val maybeTimestampAndOffset2 = remoteLogManager.findOffsetByTimestamp(tp, 
ts + 2, startOffset, leaderEpochFileCache)
-    assertEquals(Some(new TimestampAndOffset(ts + 2, startOffset + 2, 
Optional.of(targetLeaderEpoch))), maybeTimestampAndOffset2)
-
-    // Fetching message for `ts+3` will return None as there are no records 
with timestamp >= ts+3.
-    val maybeTimestampAndOffset3 = remoteLogManager.findOffsetByTimestamp(tp, 
ts + 3, startOffset, leaderEpochFileCache)
-    assertEquals(None, maybeTimestampAndOffset3)
-  }
-
-  @Test
-  def testIdempotentClose(): Unit = {
-    remoteLogManager.close()
-    remoteLogManager.close()
-    val inorder = inOrder(remoteStorageManager, remoteLogMetadataManager)
-    inorder.verify(remoteStorageManager, times(1)).close()
-    inorder.verify(remoteLogMetadataManager, times(1)).close()
-  }
-
-  private def mockPartition(topicIdPartition: TopicIdPartition) = {
-    val tp = topicIdPartition.topicPartition()
-    val partition: Partition = mock(classOf[Partition])
-    val log = mock(classOf[UnifiedLog])
-    when(partition.topicPartition).thenReturn(tp)
-    when(partition.topic).thenReturn(tp.topic())
-    when(log.remoteLogEnabled()).thenReturn(true)
-    when(partition.log).thenReturn(Some(log))
-    partition
-  }
-
-  private def createRLMConfig(props: Properties): RemoteLogManagerConfig = {
-    props.put(RemoteLogManagerConfig.REMOTE_LOG_STORAGE_SYSTEM_ENABLE_PROP, 
true.toString)
-    props.put(RemoteLogManagerConfig.REMOTE_STORAGE_MANAGER_CLASS_NAME_PROP, 
classOf[NoOpRemoteStorageManager].getName)
-    
props.put(RemoteLogManagerConfig.REMOTE_LOG_METADATA_MANAGER_CLASS_NAME_PROP, 
classOf[NoOpRemoteLogMetadataManager].getName)
-    val config = new AbstractConfig(RemoteLogManagerConfig.CONFIG_DEF, props)
-    new RemoteLogManagerConfig(config)
-  }
-
-}
diff --git 
a/server-common/src/main/java/org/apache/kafka/server/common/CheckpointFile.java
 
b/server-common/src/main/java/org/apache/kafka/server/common/CheckpointFile.java
index aca730baf6e..747a16a6d29 100644
--- 
a/server-common/src/main/java/org/apache/kafka/server/common/CheckpointFile.java
+++ 
b/server-common/src/main/java/org/apache/kafka/server/common/CheckpointFile.java
@@ -77,20 +77,8 @@ public class CheckpointFile<T> {
             // write to temp file and then swap with the existing file
             try (FileOutputStream fileOutputStream = new 
FileOutputStream(tempPath.toFile());
                  BufferedWriter writer = new BufferedWriter(new 
OutputStreamWriter(fileOutputStream, StandardCharsets.UTF_8))) {
-                // Write the version
-                writer.write(Integer.toString(version));
-                writer.newLine();
-
-                // Write the entries count
-                writer.write(Integer.toString(entries.size()));
-                writer.newLine();
-
-                // Write each entry on a new line.
-                for (T entry : entries) {
-                    writer.write(formatter.toString(entry));
-                    writer.newLine();
-                }
-
+                CheckpointWriteBuffer<T> checkpointWriteBuffer = new 
CheckpointWriteBuffer<>(writer, version, formatter);
+                checkpointWriteBuffer.write(entries);
                 writer.flush();
                 fileOutputStream.getFD().sync();
             }
@@ -109,25 +97,28 @@ public class CheckpointFile<T> {
     }
 
     public static class CheckpointWriteBuffer<T> {
-        private BufferedWriter writer;
-        private int version;
-        private EntryFormatter<T> formatter;
+        private final BufferedWriter writer;
+        private final int version;
+        private final EntryFormatter<T> formatter;
 
         public CheckpointWriteBuffer(BufferedWriter writer,
                                      int version,
                                      EntryFormatter<T> formatter) {
-            this.version = version;
             this.writer = writer;
+            this.version = version;
             this.formatter = formatter;
         }
 
-        public void write(List<T> entries) throws IOException {
-            writer.write(String.valueOf(version));
+        public void write(Collection<T> entries) throws IOException {
+            // Write the version
+            writer.write(Integer.toString(version));
             writer.newLine();
 
-            writer.write(String.valueOf(entries.size()));
+            // Write the entries count
+            writer.write(Integer.toString(entries.size()));
             writer.newLine();
 
+            // Write each entry on a new line.
             for (T entry : entries) {
                 writer.write(formatter.toString(entry));
                 writer.newLine();
diff --git 
a/storage/src/main/java/org/apache/kafka/storage/internals/checkpoint/InMemoryLeaderEpochCheckpoint.java
 
b/storage/src/main/java/org/apache/kafka/storage/internals/checkpoint/InMemoryLeaderEpochCheckpoint.java
index cd7fdc2f893..3ef30b2502a 100644
--- 
a/storage/src/main/java/org/apache/kafka/storage/internals/checkpoint/InMemoryLeaderEpochCheckpoint.java
+++ 
b/storage/src/main/java/org/apache/kafka/storage/internals/checkpoint/InMemoryLeaderEpochCheckpoint.java
@@ -52,14 +52,12 @@ public class InMemoryLeaderEpochCheckpoint implements 
LeaderEpochCheckpoint {
 
     public ByteBuffer readAsByteBuffer() throws IOException {
         ByteArrayOutputStream stream = new ByteArrayOutputStream();
-        BufferedWriter writer = new BufferedWriter(new 
OutputStreamWriter(stream, StandardCharsets.UTF_8));
-        CheckpointFile.CheckpointWriteBuffer<EpochEntry> writeBuffer = new 
CheckpointFile.CheckpointWriteBuffer<>(writer, 0, 
LeaderEpochCheckpointFile.FORMATTER);
-        try {
+        try (BufferedWriter writer = new BufferedWriter(new 
OutputStreamWriter(stream, StandardCharsets.UTF_8));) {
+            CheckpointFile.CheckpointWriteBuffer<EpochEntry> writeBuffer = new 
CheckpointFile.CheckpointWriteBuffer<>(writer, 0, 
LeaderEpochCheckpointFile.FORMATTER);
             writeBuffer.write(epochs);
             writer.flush();
-            return ByteBuffer.wrap(stream.toByteArray());
-        } finally {
-            writer.close();
         }
+
+        return ByteBuffer.wrap(stream.toByteArray());
     }
 }
\ No newline at end of file
diff --git 
a/storage/src/main/java/org/apache/kafka/storage/internals/epoch/LeaderEpochFileCache.java
 
b/storage/src/main/java/org/apache/kafka/storage/internals/epoch/LeaderEpochFileCache.java
index 1db53133578..55d7be4029a 100644
--- 
a/storage/src/main/java/org/apache/kafka/storage/internals/epoch/LeaderEpochFileCache.java
+++ 
b/storage/src/main/java/org/apache/kafka/storage/internals/epoch/LeaderEpochFileCache.java
@@ -43,13 +43,13 @@ import static 
org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.UND
  * Offset = offset of the first message in each epoch.
  */
 public class LeaderEpochFileCache {
+    private final TopicPartition topicPartition;
     private final LeaderEpochCheckpoint checkpoint;
     private final Logger log;
 
     private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
     private final TreeMap<Integer, EpochEntry> epochs = new TreeMap<>();
 
-    private final TopicPartition topicPartition;
 
     /**
      * @param topicPartition the associated topic partition
@@ -366,6 +366,16 @@ public class LeaderEpochFileCache {
         }
     }
 
+    public LeaderEpochFileCache writeTo(LeaderEpochCheckpoint 
leaderEpochCheckpoint) {
+        lock.readLock().lock();
+        try {
+            leaderEpochCheckpoint.write(epochEntries());
+            return new LeaderEpochFileCache(topicPartition, 
leaderEpochCheckpoint);
+        } finally {
+            lock.readLock().unlock();
+        }
+    }
+
     /**
      * Delete all entries.
      */
diff --git 
a/storage/src/main/java/org/apache/kafka/storage/internals/log/ProducerStateManager.java
 
b/storage/src/main/java/org/apache/kafka/storage/internals/log/ProducerStateManager.java
index 774e5be1ac1..6212711cb27 100644
--- 
a/storage/src/main/java/org/apache/kafka/storage/internals/log/ProducerStateManager.java
+++ 
b/storage/src/main/java/org/apache/kafka/storage/internals/log/ProducerStateManager.java
@@ -559,6 +559,10 @@ public class ProducerStateManager {
         }
     }
 
+    public Optional<File> fetchSnapshot(long offset) {
+        return Optional.of(snapshots.get(offset)).map(x -> x.file());
+    }
+
     private Optional<SnapshotFile> oldestSnapshotFile() {
         return Optional.ofNullable(snapshots.firstEntry()).map(x -> 
x.getValue());
     }

Reply via email to