This is an automated email from the ASF dual-hosted git repository.

mittal pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 239dce3e041 KAFKA-19291: Increase the timeout of remote storage share 
fetch requests in purgatory (#19757)
239dce3e041 is described below

commit 239dce3e0415061d070b71874ece23ae5df14687
Author: Abhinav Dixit <adi...@confluent.io>
AuthorDate: Thu May 22 12:11:33 2025 +0530

    KAFKA-19291: Increase the timeout of remote storage share fetch requests in 
purgatory (#19757)
    
    ### About
    Consumer groups have a different timeout `REMOTE_FETCH_MAX_WAIT_MS_PROP`
    in delayed fetch purgatory for fetch requests having remote storage
    fetch ([code
    link](https://github.com/apache/kafka/blob/trunk/core/src/main/scala/kafka/server/ReplicaManager.scala#L1669)).
    This is done before the request enters the purgatory, so it's easy to
    change. At the moment share groups can only have a `waitTimeMs` of
    `shareFetch.fetchParams().maxWaitMs` (default value `500ms`) for delayed
    share fetch purgatory, regardless of whether the fetch is served from
    remote storage or from the local log.
    This PR introduces a way to increase the timeout of remote storage fetch
    requests: if a remote storage fetch request cannot complete within
    `shareFetch.fetchParams().maxWaitMs`, we create a timer task which
    can be interrupted whenever `pendingFetches` is finished. This change
    was made to avoid the expiration of remote storage share fetch
    requests.
    
    ### Testing
    The code has been tested with the help of unit tests and
    `LocalTieredStorage.java`.
    
    Reviewers: Apoorv Mittal <apoorvmitta...@gmail.com>
---
 .../java/kafka/server/share/DelayedShareFetch.java |  77 ++++++++-
 .../kafka/server/share/PendingRemoteFetches.java   |   8 +
 .../kafka/server/share/SharePartitionManager.java  |  12 +-
 .../src/main/scala/kafka/server/BrokerServer.scala |   1 +
 .../kafka/server/share/DelayedShareFetchTest.java  | 178 ++++++++++++++++++++-
 .../server/share/SharePartitionManagerTest.java    |   2 +
 .../unit/kafka/server/ReplicaManagerTest.scala     |   3 +-
 7 files changed, 274 insertions(+), 7 deletions(-)

diff --git a/core/src/main/java/kafka/server/share/DelayedShareFetch.java 
b/core/src/main/java/kafka/server/share/DelayedShareFetch.java
index addfc691f1a..0a46e834eb7 100644
--- a/core/src/main/java/kafka/server/share/DelayedShareFetch.java
+++ b/core/src/main/java/kafka/server/share/DelayedShareFetch.java
@@ -40,6 +40,7 @@ import 
org.apache.kafka.server.share.fetch.ShareFetchPartitionData;
 import org.apache.kafka.server.share.metrics.ShareGroupMetrics;
 import org.apache.kafka.server.storage.log.FetchIsolation;
 import org.apache.kafka.server.storage.log.FetchPartitionData;
+import org.apache.kafka.server.util.timer.TimerTask;
 import org.apache.kafka.storage.internals.log.FetchDataInfo;
 import org.apache.kafka.storage.internals.log.LogOffsetMetadata;
 import org.apache.kafka.storage.internals.log.LogOffsetSnapshot;
@@ -64,6 +65,7 @@ import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.locks.Lock;
 import java.util.function.BiConsumer;
 import java.util.stream.Collectors;
@@ -107,6 +109,8 @@ public class DelayedShareFetch extends DelayedOperation {
     private LinkedHashMap<TopicIdPartition, LogReadResult> 
localPartitionsAlreadyFetched;
     private Optional<PendingRemoteFetches> pendingRemoteFetchesOpt;
     private Optional<Exception> remoteStorageFetchException;
+    private final AtomicBoolean outsidePurgatoryCallbackLock;
+    private final long remoteFetchMaxWaitMs;
 
     /**
      * This function constructs an instance of delayed share fetch operation 
for completing share fetch
@@ -118,6 +122,7 @@ public class DelayedShareFetch extends DelayedOperation {
      * @param sharePartitions The share partitions referenced in the share 
fetch request.
      * @param shareGroupMetrics The share group metrics to record the metrics.
      * @param time The system time.
+     * @param remoteFetchMaxWaitMs The max wait time for a share fetch request 
having remote storage fetch.
      */
     public DelayedShareFetch(
             ShareFetch shareFetch,
@@ -125,7 +130,8 @@ public class DelayedShareFetch extends DelayedOperation {
             BiConsumer<SharePartitionKey, Throwable> exceptionHandler,
             LinkedHashMap<TopicIdPartition, SharePartition> sharePartitions,
             ShareGroupMetrics shareGroupMetrics,
-            Time time
+            Time time,
+            long remoteFetchMaxWaitMs
     ) {
         this(shareFetch,
             replicaManager,
@@ -135,7 +141,8 @@ public class DelayedShareFetch extends DelayedOperation {
             shareGroupMetrics,
             time,
             Optional.empty(),
-            Uuid.randomUuid()
+            Uuid.randomUuid(),
+            remoteFetchMaxWaitMs
         );
     }
 
@@ -151,6 +158,7 @@ public class DelayedShareFetch extends DelayedOperation {
      * @param shareGroupMetrics The share group metrics to record the metrics.
      * @param time The system time.
      * @param pendingRemoteFetchesOpt Optional containing an in-flight remote 
fetch object or an empty optional.
+     * @param remoteFetchMaxWaitMs The max wait time for a share fetch request 
having remote storage fetch.
      */
     DelayedShareFetch(
         ShareFetch shareFetch,
@@ -161,7 +169,8 @@ public class DelayedShareFetch extends DelayedOperation {
         ShareGroupMetrics shareGroupMetrics,
         Time time,
         Optional<PendingRemoteFetches> pendingRemoteFetchesOpt,
-        Uuid fetchId
+        Uuid fetchId,
+        long remoteFetchMaxWaitMs
     ) {
         super(shareFetch.fetchParams().maxWaitMs, Optional.empty());
         this.shareFetch = shareFetch;
@@ -177,6 +186,8 @@ public class DelayedShareFetch extends DelayedOperation {
         this.pendingRemoteFetchesOpt = pendingRemoteFetchesOpt;
         this.remoteStorageFetchException = Optional.empty();
         this.fetchId = fetchId;
+        this.outsidePurgatoryCallbackLock = new AtomicBoolean(false);
+        this.remoteFetchMaxWaitMs = remoteFetchMaxWaitMs;
         // Register metrics for DelayedShareFetch.
         KafkaMetricsGroup metricsGroup = new KafkaMetricsGroup("kafka.server", 
"DelayedShareFetchMetrics");
         this.expiredRequestMeter = metricsGroup.newMeter(EXPIRES_PER_SEC, 
"requests", TimeUnit.SECONDS);
@@ -205,6 +216,12 @@ public class DelayedShareFetch extends DelayedOperation {
             if (remoteStorageFetchException.isPresent()) {
                 completeErroneousRemoteShareFetchRequest();
             } else if (pendingRemoteFetchesOpt.isPresent()) {
+                if (maybeRegisterCallbackPendingRemoteFetch()) {
+                    log.trace("Registered remote storage fetch callback for 
group {}, member {}, "
+                            + "topic partitions {}", shareFetch.groupId(), 
shareFetch.memberId(),
+                        partitionsAcquired.keySet());
+                    return;
+                }
                 completeRemoteStorageShareFetchRequest();
             } else {
                 completeLocalLogShareFetchRequest();
@@ -626,6 +643,16 @@ public class DelayedShareFetch extends DelayedOperation {
         return pendingRemoteFetchesOpt.orElse(null);
     }
 
+    // Visible for testing.
+    boolean outsidePurgatoryCallbackLock() {
+        return outsidePurgatoryCallbackLock.get();
+    }
+
+    // Only used for testing purpose.
+    void updatePartitionsAcquired(LinkedHashMap<TopicIdPartition, Long> 
partitionsAcquired) {
+        this.partitionsAcquired = partitionsAcquired;
+    }
+
     // Visible for testing.
     Meter expiredRequestMeter() {
         return expiredRequestMeter;
@@ -666,6 +693,28 @@ public class DelayedShareFetch extends DelayedOperation {
         return maybeCompletePendingRemoteFetch();
     }
 
+    private boolean maybeRegisterCallbackPendingRemoteFetch() {
+        log.trace("Registering callback pending remote fetch");
+        PendingRemoteFetches pendingFetch = pendingRemoteFetchesOpt.get();
+        if (!pendingFetch.isDone() && shareFetch.fetchParams().maxWaitMs < 
remoteFetchMaxWaitMs) {
+            TimerTask timerTask = new PendingRemoteFetchTimerTask();
+            pendingFetch.invokeCallbackOnCompletion(((ignored, throwable) -> {
+                timerTask.cancel();
+                log.trace("Invoked remote storage fetch callback for group {}, 
member {}, "
+                        + "topic partitions {}", shareFetch.groupId(), 
shareFetch.memberId(),
+                    partitionsAcquired.keySet());
+                if (throwable != null) {
+                    log.error("Remote storage fetch failed for group {}, 
member {}, topic partitions {}",
+                        shareFetch.groupId(), shareFetch.memberId(), 
sharePartitions.keySet(), throwable);
+                }
+                completeRemoteShareFetchRequestOutsidePurgatory();
+            }));
+            replicaManager.addShareFetchTimerRequest(timerTask);
+            return true;
+        }
+        return false;
+    }
+
     /**
      * Throws an exception if a task for remote storage fetch could not be 
scheduled successfully else updates pendingRemoteFetchesOpt.
      * @param remoteStorageFetchInfoMap - The remote storage fetch information.
@@ -904,4 +953,26 @@ public class DelayedShareFetch extends DelayedOperation {
         }
         return completedByMe;
     }
+
+    private void completeRemoteShareFetchRequestOutsidePurgatory() {
+        if (outsidePurgatoryCallbackLock.compareAndSet(false, true)) {
+            completeRemoteStorageShareFetchRequest();
+        }
+    }
+
+    private class PendingRemoteFetchTimerTask extends TimerTask {
+
+        public PendingRemoteFetchTimerTask() {
+            super(remoteFetchMaxWaitMs - shareFetch.fetchParams().maxWaitMs);
+        }
+
+        @Override
+        public void run() {
+            log.trace("Expired remote storage fetch callback for group {}, 
member {}, "
+                    + "topic partitions {}", shareFetch.groupId(), 
shareFetch.memberId(),
+                partitionsAcquired.keySet());
+            expiredRequestMeter.mark();
+            completeRemoteShareFetchRequestOutsidePurgatory();
+        }
+    }
 }
diff --git a/core/src/main/java/kafka/server/share/PendingRemoteFetches.java 
b/core/src/main/java/kafka/server/share/PendingRemoteFetches.java
index 2eb92672dc5..c3ac9c3b553 100644
--- a/core/src/main/java/kafka/server/share/PendingRemoteFetches.java
+++ b/core/src/main/java/kafka/server/share/PendingRemoteFetches.java
@@ -23,10 +23,12 @@ import 
org.apache.kafka.storage.internals.log.LogOffsetMetadata;
 import org.apache.kafka.storage.internals.log.RemoteLogReadResult;
 import org.apache.kafka.storage.internals.log.RemoteStorageFetchInfo;
 
+import java.util.ArrayList;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.Future;
+import java.util.function.BiConsumer;
 
 /**
  * This class is used to store the remote storage fetch information for topic 
partitions in a share fetch request.
@@ -48,6 +50,12 @@ public class PendingRemoteFetches {
         return true;
     }
 
+    public void invokeCallbackOnCompletion(BiConsumer<Void, Throwable> 
callback) {
+        List<CompletableFuture<RemoteLogReadResult>> remoteFetchResult = new 
ArrayList<>();
+        remoteFetches.forEach(remoteFetch -> 
remoteFetchResult.add(remoteFetch.remoteFetchResult()));
+        CompletableFuture.allOf(remoteFetchResult.toArray(new 
CompletableFuture<?>[0])).whenComplete(callback);
+    }
+
     public List<RemoteFetch> remoteFetches() {
         return remoteFetches;
     }
diff --git a/core/src/main/java/kafka/server/share/SharePartitionManager.java 
b/core/src/main/java/kafka/server/share/SharePartitionManager.java
index 1c7edd1b0af..5f0bf1fa239 100644
--- a/core/src/main/java/kafka/server/share/SharePartitionManager.java
+++ b/core/src/main/java/kafka/server/share/SharePartitionManager.java
@@ -130,6 +130,10 @@ public class SharePartitionManager implements 
AutoCloseable {
      * The max delivery count is the maximum number of times a message can be 
delivered before it is considered to be archived.
      */
     private final int maxDeliveryCount;
+    /**
+     * The max wait time for a share fetch request having remote storage fetch.
+     */
+    private final long remoteFetchMaxWaitMs;
 
     /**
      * The persister is used to persist the share partition state.
@@ -153,6 +157,7 @@ public class SharePartitionManager implements AutoCloseable 
{
         int defaultRecordLockDurationMs,
         int maxDeliveryCount,
         int maxInFlightMessages,
+        long remoteFetchMaxWaitMs,
         Persister persister,
         GroupConfigManager groupConfigManager,
         BrokerTopicStats brokerTopicStats
@@ -164,6 +169,7 @@ public class SharePartitionManager implements AutoCloseable 
{
             defaultRecordLockDurationMs,
             maxDeliveryCount,
             maxInFlightMessages,
+            remoteFetchMaxWaitMs,
             persister,
             groupConfigManager,
             new ShareGroupMetrics(time),
@@ -179,6 +185,7 @@ public class SharePartitionManager implements AutoCloseable 
{
         int defaultRecordLockDurationMs,
         int maxDeliveryCount,
         int maxInFlightMessages,
+        long remoteFetchMaxWaitMs,
         Persister persister,
         GroupConfigManager groupConfigManager,
         ShareGroupMetrics shareGroupMetrics,
@@ -193,6 +200,7 @@ public class SharePartitionManager implements AutoCloseable 
{
                 new SystemTimer("share-group-lock-timeout")),
             maxDeliveryCount,
             maxInFlightMessages,
+            remoteFetchMaxWaitMs,
             persister,
             groupConfigManager,
             shareGroupMetrics,
@@ -210,6 +218,7 @@ public class SharePartitionManager implements AutoCloseable 
{
             Timer timer,
             int maxDeliveryCount,
             int maxInFlightMessages,
+            long remoteFetchMaxWaitMs,
             Persister persister,
             GroupConfigManager groupConfigManager,
             ShareGroupMetrics shareGroupMetrics,
@@ -223,6 +232,7 @@ public class SharePartitionManager implements AutoCloseable 
{
         this.timer = timer;
         this.maxDeliveryCount = maxDeliveryCount;
         this.maxInFlightMessages = maxInFlightMessages;
+        this.remoteFetchMaxWaitMs = remoteFetchMaxWaitMs;
         this.persister = persister;
         this.groupConfigManager = groupConfigManager;
         this.shareGroupMetrics = shareGroupMetrics;
@@ -683,7 +693,7 @@ public class SharePartitionManager implements AutoCloseable 
{
         // Add the share fetch to the delayed share fetch purgatory to process 
the fetch request.
         // The request will be added irrespective of whether the share 
partition is initialized or not.
         // Once the share partition is initialized, the delayed share fetch 
will be completed.
-        addDelayedShareFetch(new DelayedShareFetch(shareFetch, replicaManager, 
fencedSharePartitionHandler(), sharePartitions, shareGroupMetrics, time), 
delayedShareFetchWatchKeys);
+        addDelayedShareFetch(new DelayedShareFetch(shareFetch, replicaManager, 
fencedSharePartitionHandler(), sharePartitions, shareGroupMetrics, time, 
remoteFetchMaxWaitMs), delayedShareFetchWatchKeys);
     }
 
     private SharePartition getOrCreateSharePartition(SharePartitionKey 
sharePartitionKey) {
diff --git a/core/src/main/scala/kafka/server/BrokerServer.scala 
b/core/src/main/scala/kafka/server/BrokerServer.scala
index b4be10656e2..61892ea66c3 100644
--- a/core/src/main/scala/kafka/server/BrokerServer.scala
+++ b/core/src/main/scala/kafka/server/BrokerServer.scala
@@ -441,6 +441,7 @@ class BrokerServer(
         config.shareGroupConfig.shareGroupRecordLockDurationMs,
         config.shareGroupConfig.shareGroupDeliveryCountLimit,
         config.shareGroupConfig.shareGroupPartitionMaxRecordLocks,
+        config.remoteLogManagerConfig.remoteFetchMaxWaitMs().toLong,
         persister,
         groupConfigManager,
         brokerTopicStats
diff --git a/core/src/test/java/kafka/server/share/DelayedShareFetchTest.java 
b/core/src/test/java/kafka/server/share/DelayedShareFetchTest.java
index 1ec8cccffa9..498047b890a 100644
--- a/core/src/test/java/kafka/server/share/DelayedShareFetchTest.java
+++ b/core/src/test/java/kafka/server/share/DelayedShareFetchTest.java
@@ -48,6 +48,7 @@ import org.apache.kafka.server.util.MockTime;
 import org.apache.kafka.server.util.timer.SystemTimer;
 import org.apache.kafka.server.util.timer.SystemTimerReaper;
 import org.apache.kafka.server.util.timer.Timer;
+import org.apache.kafka.server.util.timer.TimerTask;
 import org.apache.kafka.storage.internals.log.FetchDataInfo;
 import org.apache.kafka.storage.internals.log.LogOffsetMetadata;
 import org.apache.kafka.storage.internals.log.LogOffsetSnapshot;
@@ -81,6 +82,7 @@ import scala.jdk.javaapi.CollectionConverters;
 
 import static kafka.server.share.PendingRemoteFetches.RemoteFetch;
 import static 
kafka.server.share.SharePartitionManagerTest.DELAYED_SHARE_FETCH_PURGATORY_PURGE_INTERVAL;
+import static 
kafka.server.share.SharePartitionManagerTest.REMOTE_FETCH_MAX_WAIT_MS;
 import static kafka.server.share.SharePartitionManagerTest.buildLogReadResult;
 import static 
kafka.server.share.SharePartitionManagerTest.mockReplicaManagerDelayedShareFetch;
 import static 
org.apache.kafka.server.share.fetch.ShareFetchTestUtils.createShareAcquiredRecords;
@@ -1427,6 +1429,13 @@ public class DelayedShareFetchTest {
         
when(replicaManager.remoteLogManager()).thenReturn(Option.apply(remoteLogManager));
         
when(replicaManager.getPartitionOrException(tp2.topicPartition())).thenThrow(mock(KafkaStorageException.class));
 
+        // Mock the behaviour of replica manager such that remote storage 
fetch completion timer task completes on adding it to the watch queue.
+        doAnswer(invocationOnMock -> {
+            TimerTask timerTask = invocationOnMock.getArgument(0);
+            timerTask.run();
+            return null;
+        }).when(replicaManager).addShareFetchTimerRequest(any());
+
         Uuid fetchId = Uuid.randomUuid();
         DelayedShareFetch delayedShareFetch = 
spy(DelayedShareFetchBuilder.builder()
             .withShareFetchData(shareFetch)
@@ -1777,6 +1786,165 @@ public class DelayedShareFetchTest {
         delayedShareFetch.lock().unlock();
     }
 
+    @Test
+    public void 
testRemoteStorageFetchCompletionPostRegisteringCallbackByPendingFetchesCompletion()
 {
+        ReplicaManager replicaManager = mock(ReplicaManager.class);
+        TopicIdPartition tp0 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo", 0));
+        SharePartition sp0 = mock(SharePartition.class);
+
+        when(sp0.canAcquireRecords()).thenReturn(true);
+        when(sp0.nextFetchOffset()).thenReturn(10L);
+
+        LinkedHashMap<TopicIdPartition, SharePartition> sharePartitions = new 
LinkedHashMap<>();
+        sharePartitions.put(tp0, sp0);
+
+        CompletableFuture<Map<TopicIdPartition, 
ShareFetchResponseData.PartitionData>> future = new CompletableFuture<>();
+        ShareFetch shareFetch = new ShareFetch(FETCH_PARAMS, "grp", 
Uuid.randomUuid().toString(),
+            future, List.of(tp0), BATCH_SIZE, MAX_FETCH_RECORDS,
+            BROKER_TOPIC_STATS);
+
+        PendingRemoteFetches pendingRemoteFetches = 
mock(PendingRemoteFetches.class);
+        Uuid fetchId = Uuid.randomUuid();
+        DelayedShareFetch delayedShareFetch = 
spy(DelayedShareFetchBuilder.builder()
+            .withShareFetchData(shareFetch)
+            .withReplicaManager(replicaManager)
+            .withSharePartitions(sharePartitions)
+            
.withPartitionMaxBytesStrategy(PartitionMaxBytesStrategy.type(PartitionMaxBytesStrategy.StrategyType.UNIFORM))
+            .withPendingRemoteFetches(pendingRemoteFetches)
+            .withFetchId(fetchId)
+            .build());
+
+        LinkedHashMap<TopicIdPartition, Long> partitionsAcquired = new 
LinkedHashMap<>();
+        partitionsAcquired.put(tp0, 10L);
+
+        // Manually update acquired partitions.
+        delayedShareFetch.updatePartitionsAcquired(partitionsAcquired);
+
+        // Mock remote fetch result.
+        RemoteFetch remoteFetch = mock(RemoteFetch.class);
+        when(remoteFetch.topicIdPartition()).thenReturn(tp0);
+        
when(remoteFetch.remoteFetchResult()).thenReturn(CompletableFuture.completedFuture(
+            new RemoteLogReadResult(Optional.of(REMOTE_FETCH_INFO), 
Optional.empty()))
+        );
+        when(remoteFetch.logReadResult()).thenReturn(new LogReadResult(
+            REMOTE_FETCH_INFO,
+            Option.empty(),
+            -1L,
+            -1L,
+            -1L,
+            -1L,
+            -1L,
+            Option.empty(),
+            Option.empty(),
+            Option.empty()
+        ));
+        
when(pendingRemoteFetches.remoteFetches()).thenReturn(List.of(remoteFetch));
+        when(pendingRemoteFetches.isDone()).thenReturn(false);
+
+        // Make sure that the callback is called to complete remote storage 
share fetch result.
+        doAnswer(invocationOnMock -> {
+            BiConsumer<Void, Throwable> callback = 
invocationOnMock.getArgument(0);
+            callback.accept(mock(Void.class), null);
+            return null;
+        }).when(pendingRemoteFetches).invokeCallbackOnCompletion(any());
+
+        when(sp0.acquire(anyString(), anyInt(), anyInt(), anyLong(), 
any(FetchPartitionData.class), any())).thenReturn(
+            createShareAcquiredRecords(new 
ShareFetchResponseData.AcquiredRecords().setFirstOffset(0).setLastOffset(3).setDeliveryCount((short)
 1)));
+
+        assertFalse(delayedShareFetch.isCompleted());
+        delayedShareFetch.forceComplete();
+        assertTrue(delayedShareFetch.isCompleted());
+        // the future of shareFetch completes.
+        assertTrue(shareFetch.isCompleted());
+        assertEquals(Set.of(tp0), future.join().keySet());
+        // Verify the locks are released for tp0.
+        Mockito.verify(delayedShareFetch, 
times(1)).releasePartitionLocks(Set.of(tp0));
+        assertTrue(delayedShareFetch.outsidePurgatoryCallbackLock());
+        assertTrue(delayedShareFetch.lock().tryLock());
+        delayedShareFetch.lock().unlock();
+    }
+
+    @Test
+    public void 
testRemoteStorageFetchCompletionPostRegisteringCallbackByTimerTaskCompletion() {
+        ReplicaManager replicaManager = mock(ReplicaManager.class);
+        TopicIdPartition tp0 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo", 0));
+        SharePartition sp0 = mock(SharePartition.class);
+
+        when(sp0.canAcquireRecords()).thenReturn(true);
+        when(sp0.nextFetchOffset()).thenReturn(10L);
+
+        LinkedHashMap<TopicIdPartition, SharePartition> sharePartitions = new 
LinkedHashMap<>();
+        sharePartitions.put(tp0, sp0);
+
+        CompletableFuture<Map<TopicIdPartition, 
ShareFetchResponseData.PartitionData>> future = new CompletableFuture<>();
+        ShareFetch shareFetch = new ShareFetch(FETCH_PARAMS, "grp", 
Uuid.randomUuid().toString(),
+            future, List.of(tp0), BATCH_SIZE, MAX_FETCH_RECORDS,
+            BROKER_TOPIC_STATS);
+
+        PendingRemoteFetches pendingRemoteFetches = 
mock(PendingRemoteFetches.class);
+        Uuid fetchId = Uuid.randomUuid();
+        DelayedShareFetch delayedShareFetch = 
spy(DelayedShareFetchBuilder.builder()
+            .withShareFetchData(shareFetch)
+            .withReplicaManager(replicaManager)
+            .withSharePartitions(sharePartitions)
+            
.withPartitionMaxBytesStrategy(PartitionMaxBytesStrategy.type(PartitionMaxBytesStrategy.StrategyType.UNIFORM))
+            .withPendingRemoteFetches(pendingRemoteFetches)
+            .withFetchId(fetchId)
+            .build());
+
+        LinkedHashMap<TopicIdPartition, Long> partitionsAcquired = new 
LinkedHashMap<>();
+        partitionsAcquired.put(tp0, 10L);
+
+        // Manually update acquired partitions.
+        delayedShareFetch.updatePartitionsAcquired(partitionsAcquired);
+
+        // Mock remote fetch result.
+        RemoteFetch remoteFetch = mock(RemoteFetch.class);
+        when(remoteFetch.topicIdPartition()).thenReturn(tp0);
+        
when(remoteFetch.remoteFetchResult()).thenReturn(CompletableFuture.completedFuture(
+            new RemoteLogReadResult(Optional.of(REMOTE_FETCH_INFO), 
Optional.empty()))
+        );
+        when(remoteFetch.logReadResult()).thenReturn(new LogReadResult(
+            REMOTE_FETCH_INFO,
+            Option.empty(),
+            -1L,
+            -1L,
+            -1L,
+            -1L,
+            -1L,
+            Option.empty(),
+            Option.empty(),
+            Option.empty()
+        ));
+        
when(pendingRemoteFetches.remoteFetches()).thenReturn(List.of(remoteFetch));
+        when(pendingRemoteFetches.isDone()).thenReturn(false);
+
+        // Make sure that the callback to complete remote storage share fetch 
result is not called.
+        doAnswer(invocationOnMock -> 
null).when(pendingRemoteFetches).invokeCallbackOnCompletion(any());
+
+        // Mock the behaviour of replica manager such that remote storage 
fetch completion timer task completes on adding it to the watch queue.
+        doAnswer(invocationOnMock -> {
+            TimerTask timerTask = invocationOnMock.getArgument(0);
+            timerTask.run();
+            return null;
+        }).when(replicaManager).addShareFetchTimerRequest(any());
+
+        when(sp0.acquire(anyString(), anyInt(), anyInt(), anyLong(), 
any(FetchPartitionData.class), any())).thenReturn(
+            createShareAcquiredRecords(new 
ShareFetchResponseData.AcquiredRecords().setFirstOffset(0).setLastOffset(3).setDeliveryCount((short)
 1)));
+
+        assertFalse(delayedShareFetch.isCompleted());
+        delayedShareFetch.forceComplete();
+        assertTrue(delayedShareFetch.isCompleted());
+        // the future of shareFetch completes.
+        assertTrue(shareFetch.isCompleted());
+        assertEquals(Set.of(tp0), future.join().keySet());
+        // Verify the locks are released for tp0.
+        Mockito.verify(delayedShareFetch, 
times(1)).releasePartitionLocks(Set.of(tp0));
+        assertTrue(delayedShareFetch.outsidePurgatoryCallbackLock());
+        assertTrue(delayedShareFetch.lock().tryLock());
+        delayedShareFetch.lock().unlock();
+    }
+
     static void mockTopicIdPartitionToReturnDataEqualToMinBytes(ReplicaManager 
replicaManager, TopicIdPartition topicIdPartition, int minBytes) {
         LogOffsetMetadata hwmOffsetMetadata = new LogOffsetMetadata(1, 1, 
minBytes);
         LogOffsetSnapshot endOffsetSnapshot = new LogOffsetSnapshot(1, 
mock(LogOffsetMetadata.class),
@@ -1847,7 +2015,7 @@ public class DelayedShareFetchTest {
         private LinkedHashMap<TopicIdPartition, SharePartition> 
sharePartitions = mock(LinkedHashMap.class);
         private PartitionMaxBytesStrategy partitionMaxBytesStrategy = 
mock(PartitionMaxBytesStrategy.class);
         private Time time = new MockTime();
-        private final Optional<PendingRemoteFetches> pendingRemoteFetches = 
Optional.empty();
+        private Optional<PendingRemoteFetches> pendingRemoteFetches = 
Optional.empty();
         private ShareGroupMetrics shareGroupMetrics = 
mock(ShareGroupMetrics.class);
         private Uuid fetchId = Uuid.randomUuid();
 
@@ -1886,6 +2054,11 @@ public class DelayedShareFetchTest {
             return this;
         }
 
+        private DelayedShareFetchBuilder 
withPendingRemoteFetches(PendingRemoteFetches pendingRemoteFetches) {
+            this.pendingRemoteFetches = Optional.of(pendingRemoteFetches);
+            return this;
+        }
+
         private DelayedShareFetchBuilder withFetchId(Uuid fetchId) {
             this.fetchId = fetchId;
             return this;
@@ -1905,7 +2078,8 @@ public class DelayedShareFetchTest {
                 shareGroupMetrics,
                 time,
                 pendingRemoteFetches,
-                fetchId);
+                fetchId,
+                REMOTE_FETCH_MAX_WAIT_MS);
         }
     }
 }
diff --git 
a/core/src/test/java/kafka/server/share/SharePartitionManagerTest.java 
b/core/src/test/java/kafka/server/share/SharePartitionManagerTest.java
index 0f821d4423a..d4c68fe555b 100644
--- a/core/src/test/java/kafka/server/share/SharePartitionManagerTest.java
+++ b/core/src/test/java/kafka/server/share/SharePartitionManagerTest.java
@@ -158,6 +158,7 @@ public class SharePartitionManagerTest {
     private static final String CONNECTION_ID = "id-1";
 
     static final int DELAYED_SHARE_FETCH_PURGATORY_PURGE_INTERVAL = 1000;
+    static final long REMOTE_FETCH_MAX_WAIT_MS = 6000L;
 
     private MockTime time;
     private ReplicaManager mockReplicaManager;
@@ -3242,6 +3243,7 @@ public class SharePartitionManagerTest {
                 timer,
                 MAX_DELIVERY_COUNT,
                 MAX_IN_FLIGHT_MESSAGES,
+                REMOTE_FETCH_MAX_WAIT_MS,
                 persister,
                 mock(GroupConfigManager.class),
                 shareGroupMetrics,
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala 
b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
index d33a68a8348..f9782e713ee 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
@@ -6123,7 +6123,8 @@ class ReplicaManagerTest {
         mock(classOf[BiConsumer[SharePartitionKey, Throwable]]),
         sharePartitions,
         mock(classOf[ShareGroupMetrics]),
-        time))
+        time,
+        500))
 
       val delayedShareFetchWatchKeys : util.List[DelayedShareFetchKey] = new 
util.ArrayList[DelayedShareFetchKey]
       topicPartitions.forEach((topicIdPartition: TopicIdPartition) => 
delayedShareFetchWatchKeys.add(new DelayedShareFetchGroupKey(groupId, 
topicIdPartition.topicId, topicIdPartition.partition)))

Reply via email to