This is an automated email from the ASF dual-hosted git repository.

schofielaj pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new c7619ef8d14 KAFKA-17951: Share partition rotate strategy (#18651)
c7619ef8d14 is described below

commit c7619ef8d1446f7f615f3cc66b2ae5d6850826b9
Author: Apoorv Mittal <[email protected]>
AuthorDate: Tue Jan 28 11:44:48 2025 +0000

    KAFKA-17951: Share partition rotate strategy (#18651)
    
    Reviewers: Andrew Schofield <[email protected]>, Abhinav Dixit 
<[email protected]>
---
 .../kafka/server/share/SharePartitionManager.java  |  11 +-
 core/src/main/scala/kafka/server/KafkaApis.scala   |   6 +-
 .../kafka/server/share/DelayedShareFetchTest.java  |  70 +++------
 .../kafka/server/share/ShareFetchUtilsTest.java    |  19 +--
 .../server/share/SharePartitionManagerTest.java    | 174 ++++++++++++---------
 .../scala/unit/kafka/server/KafkaApisTest.scala    |  36 +++--
 .../share/fetch/PartitionRotateStrategy.java       | 106 +++++++++++++
 .../kafka/server/share/fetch/ShareFetch.java       |   7 +-
 .../share/fetch/PartitionRotateStrategyTest.java   |  99 ++++++++++++
 .../kafka/server/share/fetch/ShareFetchTest.java   |  19 ++-
 .../server/share/fetch/ShareFetchTestUtils.java    |  77 +++++++++
 11 files changed, 455 insertions(+), 169 deletions(-)

diff --git a/core/src/main/java/kafka/server/share/SharePartitionManager.java 
b/core/src/main/java/kafka/server/share/SharePartitionManager.java
index dceca9d9cf7..08fb3472cd1 100644
--- a/core/src/main/java/kafka/server/share/SharePartitionManager.java
+++ b/core/src/main/java/kafka/server/share/SharePartitionManager.java
@@ -50,6 +50,8 @@ import 
org.apache.kafka.server.share.context.ShareSessionContext;
 import org.apache.kafka.server.share.fetch.DelayedShareFetchGroupKey;
 import org.apache.kafka.server.share.fetch.DelayedShareFetchKey;
 import org.apache.kafka.server.share.fetch.DelayedShareFetchPartitionKey;
+import org.apache.kafka.server.share.fetch.PartitionRotateStrategy;
+import 
org.apache.kafka.server.share.fetch.PartitionRotateStrategy.PartitionRotateMetadata;
 import org.apache.kafka.server.share.fetch.ShareFetch;
 import org.apache.kafka.server.share.persister.Persister;
 import org.apache.kafka.server.share.session.ShareSession;
@@ -261,14 +263,19 @@ public class SharePartitionManager implements 
AutoCloseable {
         String groupId,
         String memberId,
         FetchParams fetchParams,
+        int sessionEpoch,
         int batchSize,
-        Map<TopicIdPartition, Integer> partitionMaxBytes
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes
     ) {
         log.trace("Fetch request for topicIdPartitions: {} with groupId: {} 
fetch params: {}",
                 partitionMaxBytes.keySet(), groupId, fetchParams);
 
+        LinkedHashMap<TopicIdPartition, Integer> topicIdPartitions = 
PartitionRotateStrategy
+            .type(PartitionRotateStrategy.StrategyType.ROUND_ROBIN)
+            .rotate(partitionMaxBytes, new 
PartitionRotateMetadata(sessionEpoch));
+
         CompletableFuture<Map<TopicIdPartition, PartitionData>> future = new 
CompletableFuture<>();
-        processShareFetch(new ShareFetch(fetchParams, groupId, memberId, 
future, partitionMaxBytes, batchSize, maxFetchRecords, brokerTopicStats));
+        processShareFetch(new ShareFetch(fetchParams, groupId, memberId, 
future, topicIdPartitions, batchSize, maxFetchRecords, brokerTopicStats));
 
         return future;
     }
diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala 
b/core/src/main/scala/kafka/server/KafkaApis.scala
index 82132d868bf..18795a7e0f3 100644
--- a/core/src/main/scala/kafka/server/KafkaApis.scala
+++ b/core/src/main/scala/kafka/server/KafkaApis.scala
@@ -2796,9 +2796,9 @@ class KafkaApis(val requestChannel: RequestChannel,
 
     // Handling the Fetch from the ShareFetchRequest.
     // Variable to store the topic partition wise result of fetching.
-    val fetchResult: CompletableFuture[Map[TopicIdPartition, 
ShareFetchResponseData.PartitionData]] =
-      handleFetchFromShareFetchRequest(
+    val fetchResult: CompletableFuture[Map[TopicIdPartition, 
ShareFetchResponseData.PartitionData]] = handleFetchFromShareFetchRequest(
       request,
+      shareSessionEpoch,
       erroneousAndValidPartitionData,
       sharePartitionManager,
       authorizedTopics
@@ -2893,6 +2893,7 @@ class KafkaApis(val requestChannel: RequestChannel,
 
   // Visible for Testing
   def handleFetchFromShareFetchRequest(request: RequestChannel.Request,
+                                       shareSessionEpoch: Int,
                                        erroneousAndValidPartitionData: 
ErroneousAndValidPartitionData,
                                        sharePartitionManagerInstance: 
SharePartitionManager,
                                        authorizedTopics: Set[String]
@@ -2954,6 +2955,7 @@ class KafkaApis(val requestChannel: RequestChannel,
         groupId,
         shareFetchRequest.data.memberId,
         params,
+        shareSessionEpoch,
         shareFetchRequest.data.batchSize,
         interestedWithMaxBytes
       ).thenApply{ result =>
diff --git a/core/src/test/java/kafka/server/share/DelayedShareFetchTest.java 
b/core/src/test/java/kafka/server/share/DelayedShareFetchTest.java
index 4634a73e8de..6068fbb769d 100644
--- a/core/src/test/java/kafka/server/share/DelayedShareFetchTest.java
+++ b/core/src/test/java/kafka/server/share/DelayedShareFetchTest.java
@@ -54,7 +54,6 @@ import org.mockito.Mockito;
 
 import java.util.ArrayList;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.LinkedHashMap;
 import java.util.LinkedHashSet;
 import java.util.List;
@@ -72,6 +71,7 @@ import static 
kafka.server.share.SharePartitionManagerTest.DELAYED_SHARE_FETCH_P
 import static kafka.server.share.SharePartitionManagerTest.PARTITION_MAX_BYTES;
 import static kafka.server.share.SharePartitionManagerTest.buildLogReadResult;
 import static 
kafka.server.share.SharePartitionManagerTest.mockReplicaManagerDelayedShareFetch;
+import static 
org.apache.kafka.server.share.fetch.ShareFetchTestUtils.orderedMap;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -115,9 +115,7 @@ public class DelayedShareFetchTest {
         Uuid topicId = Uuid.randomUuid();
         TopicIdPartition tp0 = new TopicIdPartition(topicId, new 
TopicPartition("foo", 0));
         TopicIdPartition tp1 = new TopicIdPartition(topicId, new 
TopicPartition("foo", 1));
-        Map<TopicIdPartition, Integer> partitionMaxBytes = new HashMap<>();
-        partitionMaxBytes.put(tp0, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp1, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp0, tp1);
 
         SharePartition sp0 = mock(SharePartition.class);
         SharePartition sp1 = mock(SharePartition.class);
@@ -155,9 +153,7 @@ public class DelayedShareFetchTest {
         ReplicaManager replicaManager = mock(ReplicaManager.class);
         TopicIdPartition tp0 = new TopicIdPartition(topicId, new 
TopicPartition("foo", 0));
         TopicIdPartition tp1 = new TopicIdPartition(topicId, new 
TopicPartition("foo", 1));
-        Map<TopicIdPartition, Integer> partitionMaxBytes = new HashMap<>();
-        partitionMaxBytes.put(tp0, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp1, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp0, tp1);
 
         SharePartition sp0 = mock(SharePartition.class);
         SharePartition sp1 = mock(SharePartition.class);
@@ -217,9 +213,7 @@ public class DelayedShareFetchTest {
         ReplicaManager replicaManager = mock(ReplicaManager.class);
         TopicIdPartition tp0 = new TopicIdPartition(topicId, new 
TopicPartition("foo", 0));
         TopicIdPartition tp1 = new TopicIdPartition(topicId, new 
TopicPartition("foo", 1));
-        Map<TopicIdPartition, Integer> partitionMaxBytes = new HashMap<>();
-        partitionMaxBytes.put(tp0, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp1, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp0, tp1);
 
         SharePartition sp0 = mock(SharePartition.class);
         SharePartition sp1 = mock(SharePartition.class);
@@ -274,9 +268,7 @@ public class DelayedShareFetchTest {
         ReplicaManager replicaManager = mock(ReplicaManager.class);
         TopicIdPartition tp0 = new TopicIdPartition(topicId, new 
TopicPartition("foo", 0));
         TopicIdPartition tp1 = new TopicIdPartition(topicId, new 
TopicPartition("foo", 1));
-        Map<TopicIdPartition, Integer> partitionMaxBytes = new HashMap<>();
-        partitionMaxBytes.put(tp0, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp1, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp0, tp1);
 
         SharePartition sp0 = mock(SharePartition.class);
         SharePartition sp1 = mock(SharePartition.class);
@@ -323,9 +315,7 @@ public class DelayedShareFetchTest {
         ReplicaManager replicaManager = mock(ReplicaManager.class);
         TopicIdPartition tp0 = new TopicIdPartition(topicId, new 
TopicPartition("foo", 0));
         TopicIdPartition tp1 = new TopicIdPartition(topicId, new 
TopicPartition("foo", 1));
-        Map<TopicIdPartition, Integer> partitionMaxBytes = new HashMap<>();
-        partitionMaxBytes.put(tp0, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp1, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp0, tp1);
 
         SharePartition sp0 = mock(SharePartition.class);
         SharePartition sp1 = mock(SharePartition.class);
@@ -369,9 +359,7 @@ public class DelayedShareFetchTest {
         ReplicaManager replicaManager = mock(ReplicaManager.class);
         TopicIdPartition tp0 = new TopicIdPartition(topicId, new 
TopicPartition("foo", 0));
         TopicIdPartition tp1 = new TopicIdPartition(topicId, new 
TopicPartition("foo", 1));
-        Map<TopicIdPartition, Integer> partitionMaxBytes = new HashMap<>();
-        partitionMaxBytes.put(tp0, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp1, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp0, tp1);
 
         SharePartition sp0 = mock(SharePartition.class);
         SharePartition sp1 = mock(SharePartition.class);
@@ -419,8 +407,7 @@ public class DelayedShareFetchTest {
         Uuid topicId = Uuid.randomUuid();
         ReplicaManager replicaManager = mock(ReplicaManager.class);
         TopicIdPartition tp0 = new TopicIdPartition(topicId, new 
TopicPartition("foo", 0));
-        Map<TopicIdPartition, Integer> partitionMaxBytes = new HashMap<>();
-        partitionMaxBytes.put(tp0, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp0);
 
         SharePartition sp0 = mock(SharePartition.class);
 
@@ -469,9 +456,7 @@ public class DelayedShareFetchTest {
         TopicIdPartition tp0 = new TopicIdPartition(topicId, new 
TopicPartition("foo", 0));
         TopicIdPartition tp1 = new TopicIdPartition(topicId, new 
TopicPartition("foo", 1));
         TopicIdPartition tp2 = new TopicIdPartition(topicId, new 
TopicPartition("foo", 2));
-        Map<TopicIdPartition, Integer> partitionMaxBytes1 = new HashMap<>();
-        partitionMaxBytes1.put(tp0, PARTITION_MAX_BYTES);
-        partitionMaxBytes1.put(tp1, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes1 = 
orderedMap(PARTITION_MAX_BYTES, tp0, tp1);
 
         SharePartition sp0 = mock(SharePartition.class);
         SharePartition sp1 = mock(SharePartition.class);
@@ -513,9 +498,7 @@ public class DelayedShareFetchTest {
         assertTrue(delayedShareFetch1.lock().tryLock());
         delayedShareFetch1.lock().unlock();
 
-        Map<TopicIdPartition, Integer> partitionMaxBytes2 = new HashMap<>();
-        partitionMaxBytes2.put(tp1, PARTITION_MAX_BYTES);
-        partitionMaxBytes2.put(tp2, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes2 = 
orderedMap(PARTITION_MAX_BYTES, tp1, tp2);
         ShareFetch shareFetch2 = new ShareFetch(FETCH_PARAMS, groupId, 
Uuid.randomUuid().toString(),
             new CompletableFuture<>(), partitionMaxBytes2, BATCH_SIZE, 
MAX_FETCH_RECORDS,
             BROKER_TOPIC_STATS);
@@ -561,9 +544,7 @@ public class DelayedShareFetchTest {
         ReplicaManager replicaManager = mock(ReplicaManager.class);
         TopicIdPartition tp0 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo", 0));
         TopicIdPartition tp1 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo", 1));
-        Map<TopicIdPartition, Integer> partitionMaxBytes = new HashMap<>();
-        partitionMaxBytes.put(tp0, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp1, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp0, tp1);
 
         SharePartition sp0 = mock(SharePartition.class);
         SharePartition sp1 = mock(SharePartition.class);
@@ -619,8 +600,7 @@ public class DelayedShareFetchTest {
         Uuid topicId = Uuid.randomUuid();
         ReplicaManager replicaManager = mock(ReplicaManager.class);
         TopicIdPartition tp0 = new TopicIdPartition(topicId, new 
TopicPartition("foo", 0));
-        Map<TopicIdPartition, Integer> partitionMaxBytes = new HashMap<>();
-        partitionMaxBytes.put(tp0, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp0);
 
         SharePartition sp0 = mock(SharePartition.class);
 
@@ -697,7 +677,7 @@ public class DelayedShareFetchTest {
         mockTopicIdPartitionToReturnDataEqualToMinBytes(replicaManager, tp0, 
1);
 
         ShareFetch shareFetch = new ShareFetch(FETCH_PARAMS, groupId, 
Uuid.randomUuid().toString(),
-            new CompletableFuture<>(), Map.of(tp0, PARTITION_MAX_BYTES), 
BATCH_SIZE, MAX_FETCH_RECORDS,
+            new CompletableFuture<>(), orderedMap(PARTITION_MAX_BYTES, tp0), 
BATCH_SIZE, MAX_FETCH_RECORDS,
             BROKER_TOPIC_STATS);
 
         DelayedShareFetch delayedShareFetch = 
DelayedShareFetchTest.DelayedShareFetchBuilder.builder()
@@ -728,7 +708,7 @@ public class DelayedShareFetchTest {
         sharePartitions.put(tp0, sp0);
 
         ShareFetch shareFetch = new ShareFetch(FETCH_PARAMS, groupId, 
Uuid.randomUuid().toString(),
-            new CompletableFuture<>(), Map.of(tp0, PARTITION_MAX_BYTES), 
BATCH_SIZE, MAX_FETCH_RECORDS,
+            new CompletableFuture<>(), orderedMap(PARTITION_MAX_BYTES, tp0), 
BATCH_SIZE, MAX_FETCH_RECORDS,
             BROKER_TOPIC_STATS);
 
         DelayedShareFetch delayedShareFetch = 
DelayedShareFetchTest.DelayedShareFetchBuilder.builder()
@@ -747,8 +727,7 @@ public class DelayedShareFetchTest {
         String groupId = "grp";
         TopicIdPartition tp0 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo", 0));
         SharePartition sp0 = mock(SharePartition.class);
-        Map<TopicIdPartition, Integer> partitionMaxBytes = new HashMap<>();
-        partitionMaxBytes.put(tp0, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp0);
 
         when(sp0.maybeAcquireFetchLock()).thenReturn(true);
         when(sp0.canAcquireRecords()).thenReturn(true);
@@ -804,12 +783,7 @@ public class DelayedShareFetchTest {
         SharePartition sp3 = mock(SharePartition.class);
         SharePartition sp4 = mock(SharePartition.class);
 
-        Map<TopicIdPartition, Integer> partitionMaxBytes = new HashMap<>();
-        partitionMaxBytes.put(tp0, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp1, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp2, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp3, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp4, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp0, tp1, tp2, tp3, tp4);
 
         when(sp0.maybeAcquireFetchLock()).thenReturn(true);
         when(sp1.maybeAcquireFetchLock()).thenReturn(true);
@@ -907,12 +881,7 @@ public class DelayedShareFetchTest {
         SharePartition sp3 = mock(SharePartition.class);
         SharePartition sp4 = mock(SharePartition.class);
 
-        Map<TopicIdPartition, Integer> partitionMaxBytes = new HashMap<>();
-        partitionMaxBytes.put(tp0, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp1, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp2, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp3, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp4, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp0, tp1, tp2, tp3, tp4);
 
         when(sp0.maybeAcquireFetchLock()).thenReturn(true);
         when(sp1.maybeAcquireFetchLock()).thenReturn(true);
@@ -992,10 +961,7 @@ public class DelayedShareFetchTest {
         TopicIdPartition tp0 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo", 0));
         TopicIdPartition tp1 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo", 1));
         TopicIdPartition tp2 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo", 2));
-        Map<TopicIdPartition, Integer> partitionMaxBytes = new HashMap<>();
-        partitionMaxBytes.put(tp0, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp1, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp2, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp0, tp1, tp2);
 
         SharePartition sp0 = mock(SharePartition.class);
         SharePartition sp1 = mock(SharePartition.class);
diff --git a/core/src/test/java/kafka/server/share/ShareFetchUtilsTest.java 
b/core/src/test/java/kafka/server/share/ShareFetchUtilsTest.java
index 4de38435878..a27b65ecd2f 100644
--- a/core/src/test/java/kafka/server/share/ShareFetchUtilsTest.java
+++ b/core/src/test/java/kafka/server/share/ShareFetchUtilsTest.java
@@ -53,6 +53,7 @@ import java.util.concurrent.CompletableFuture;
 import java.util.function.BiConsumer;
 
 import static kafka.server.share.SharePartitionManagerTest.PARTITION_MAX_BYTES;
+import static 
org.apache.kafka.server.share.fetch.ShareFetchTestUtils.orderedMap;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertNull;
 import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -84,9 +85,7 @@ public class ShareFetchUtilsTest {
         String memberId = Uuid.randomUuid().toString();
         TopicIdPartition tp0 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo", 0));
         TopicIdPartition tp1 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo", 1));
-        Map<TopicIdPartition, Integer> partitionMaxBytes = new HashMap<>();
-        partitionMaxBytes.put(tp0, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp1, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp0, tp1);
 
         SharePartition sp0 = mock(SharePartition.class);
         SharePartition sp1 = mock(SharePartition.class);
@@ -151,9 +150,7 @@ public class ShareFetchUtilsTest {
         String memberId = Uuid.randomUuid().toString();
         TopicIdPartition tp0 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo", 0));
         TopicIdPartition tp1 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo", 1));
-        Map<TopicIdPartition, Integer> partitionMaxBytes = new HashMap<>();
-        partitionMaxBytes.put(tp0, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp1, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp0, tp1);
 
         SharePartition sp0 = mock(SharePartition.class);
         SharePartition sp1 = mock(SharePartition.class);
@@ -199,9 +196,7 @@ public class ShareFetchUtilsTest {
         TopicIdPartition tp0 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo", 0));
         TopicIdPartition tp1 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo", 1));
 
-        Map<TopicIdPartition, Integer> partitionMaxBytes = new HashMap<>();
-        partitionMaxBytes.put(tp0, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp1, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp0, tp1);
 
         SharePartition sp0 = Mockito.mock(SharePartition.class);
         SharePartition sp1 = Mockito.mock(SharePartition.class);
@@ -295,7 +290,7 @@ public class ShareFetchUtilsTest {
         String groupId = "grp";
 
         TopicIdPartition tp0 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo", 0));
-        Map<TopicIdPartition, Integer> partitionMaxBytes = 
Collections.singletonMap(tp0, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp0);
 
         SharePartition sp0 = Mockito.mock(SharePartition.class);
         LinkedHashMap<TopicIdPartition, SharePartition> sharePartitions = new 
LinkedHashMap<>();
@@ -357,9 +352,7 @@ public class ShareFetchUtilsTest {
         TopicIdPartition tp0 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo", 0));
         TopicIdPartition tp1 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo", 1));
 
-        Map<TopicIdPartition, Integer> partitionMaxBytes = new HashMap<>();
-        partitionMaxBytes.put(tp0, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp1, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp0, tp1);
 
         SharePartition sp0 = Mockito.mock(SharePartition.class);
         SharePartition sp1 = Mockito.mock(SharePartition.class);
diff --git 
a/core/src/test/java/kafka/server/share/SharePartitionManagerTest.java 
b/core/src/test/java/kafka/server/share/SharePartitionManagerTest.java
index 0988d59b5ec..6786e192ea5 100644
--- a/core/src/test/java/kafka/server/share/SharePartitionManagerTest.java
+++ b/core/src/test/java/kafka/server/share/SharePartitionManagerTest.java
@@ -91,6 +91,7 @@ import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.Timeout;
+import org.mockito.ArgumentCaptor;
 import org.mockito.ArgumentMatchers;
 import org.mockito.Mockito;
 
@@ -117,6 +118,8 @@ import scala.collection.Seq;
 import scala.jdk.javaapi.CollectionConverters;
 
 import static 
kafka.server.share.DelayedShareFetchTest.mockTopicIdPartitionToReturnDataEqualToMinBytes;
+import static 
org.apache.kafka.server.share.fetch.ShareFetchTestUtils.orderedMap;
+import static 
org.apache.kafka.server.share.fetch.ShareFetchTestUtils.validateRotatedMapEquals;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertInstanceOf;
@@ -1048,14 +1051,7 @@ public class SharePartitionManagerTest {
         TopicIdPartition tp4 = new TopicIdPartition(fooId, new 
TopicPartition("foo", 2));
         TopicIdPartition tp5 = new TopicIdPartition(barId, new 
TopicPartition("bar", 2));
         TopicIdPartition tp6 = new TopicIdPartition(fooId, new 
TopicPartition("foo", 3));
-        Map<TopicIdPartition, Integer> partitionMaxBytes = new HashMap<>();
-        partitionMaxBytes.put(tp0, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp1, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp2, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp3, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp4, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp5, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp6, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp0, tp1, tp2, tp3, tp4, tp5, tp6);
 
         mockFetchOffsetForTimestamp(mockReplicaManager);
 
@@ -1085,18 +1081,18 @@ public class SharePartitionManagerTest {
         doAnswer(invocation -> 
buildLogReadResult(partitionMaxBytes.keySet())).when(mockReplicaManager).readFromLog(any(),
 any(), any(ReplicaQuota.class), anyBoolean());
 
         CompletableFuture<Map<TopicIdPartition, PartitionData>> future = 
sharePartitionManager.fetchMessages(
-            groupId, memberId1.toString(), FETCH_PARAMS, BATCH_SIZE, 
partitionMaxBytes);
+            groupId, memberId1.toString(), FETCH_PARAMS, 1, BATCH_SIZE, 
partitionMaxBytes);
         assertTrue(future.isDone());
         Mockito.verify(mockReplicaManager, times(1)).readFromLog(
             any(), any(), any(ReplicaQuota.class), anyBoolean());
 
-        future = sharePartitionManager.fetchMessages(groupId, 
memberId1.toString(), FETCH_PARAMS, BATCH_SIZE,
+        future = sharePartitionManager.fetchMessages(groupId, 
memberId1.toString(), FETCH_PARAMS, 3, BATCH_SIZE,
             partitionMaxBytes);
         assertTrue(future.isDone());
         Mockito.verify(mockReplicaManager, times(2)).readFromLog(
             any(), any(), any(ReplicaQuota.class), anyBoolean());
 
-        future = sharePartitionManager.fetchMessages(groupId, 
memberId1.toString(), FETCH_PARAMS, BATCH_SIZE,
+        future = sharePartitionManager.fetchMessages(groupId, 
memberId1.toString(), FETCH_PARAMS, 10, BATCH_SIZE,
             partitionMaxBytes);
         assertTrue(future.isDone());
         Mockito.verify(mockReplicaManager, times(3)).readFromLog(
@@ -1136,11 +1132,7 @@ public class SharePartitionManagerTest {
         TopicIdPartition tp1 = new TopicIdPartition(fooId, new 
TopicPartition("foo", 1));
         TopicIdPartition tp2 = new TopicIdPartition(barId, new 
TopicPartition("bar", 0));
         TopicIdPartition tp3 = new TopicIdPartition(barId, new 
TopicPartition("bar", 1));
-        Map<TopicIdPartition, Integer> partitionMaxBytes = new HashMap<>();
-        partitionMaxBytes.put(tp0, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp1, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp2, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp3, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp0, tp1, tp2, tp3);
 
         final Time time = new MockTime(0, System.currentTimeMillis(), 0);
 
@@ -1211,7 +1203,7 @@ public class SharePartitionManagerTest {
         try {
             for (int i = 0; i != threadCount; ++i) {
                 executorService.submit(() -> {
-                    sharePartitionManager.fetchMessages(groupId, 
memberId1.toString(), FETCH_PARAMS,
+                    sharePartitionManager.fetchMessages(groupId, 
memberId1.toString(), FETCH_PARAMS, 0,
                         BATCH_SIZE, partitionMaxBytes);
                 });
                 // We are blocking the main thread at an interval of 10 
threads so that the currently running executorService threads can complete.
@@ -1235,8 +1227,7 @@ public class SharePartitionManagerTest {
         Uuid memberId = Uuid.randomUuid();
         Uuid fooId = Uuid.randomUuid();
         TopicIdPartition tp0 = new TopicIdPartition(fooId, new 
TopicPartition("foo", 0));
-        Map<TopicIdPartition, Integer> partitionMaxBytes = new HashMap<>();
-        partitionMaxBytes.put(tp0, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp0);
 
         SharePartition sp0 = mock(SharePartition.class);
         when(sp0.maybeAcquireFetchLock()).thenReturn(true);
@@ -1258,7 +1249,7 @@ public class SharePartitionManagerTest {
             .build();
 
         CompletableFuture<Map<TopicIdPartition, 
ShareFetchResponseData.PartitionData>> future =
-            sharePartitionManager.fetchMessages(groupId, memberId.toString(), 
FETCH_PARAMS,
+            sharePartitionManager.fetchMessages(groupId, memberId.toString(), 
FETCH_PARAMS, 0,
                 BATCH_SIZE, partitionMaxBytes);
         Mockito.verify(mockReplicaManager, times(0)).readFromLog(
             any(), any(), any(ReplicaQuota.class), anyBoolean());
@@ -1279,7 +1270,7 @@ public class SharePartitionManagerTest {
         Uuid memberId = Uuid.randomUuid();
         Uuid fooId = Uuid.randomUuid();
         TopicIdPartition tp0 = new TopicIdPartition(fooId, new 
TopicPartition("foo", 0));
-        Map<TopicIdPartition, Integer> partitionMaxBytes = 
Collections.singletonMap(tp0, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp0);
 
         mockFetchOffsetForTimestamp(mockReplicaManager);
 
@@ -1297,7 +1288,7 @@ public class SharePartitionManagerTest {
 
         doAnswer(invocation -> 
buildLogReadResult(partitionMaxBytes.keySet())).when(mockReplicaManager).readFromLog(any(),
 any(), any(ReplicaQuota.class), anyBoolean());
 
-        sharePartitionManager.fetchMessages(groupId, memberId.toString(), 
FETCH_PARAMS, BATCH_SIZE,
+        sharePartitionManager.fetchMessages(groupId, memberId.toString(), 
FETCH_PARAMS, 0, BATCH_SIZE,
             partitionMaxBytes);
         // Since the nextFetchOffset does not point to endOffset + 1, i.e. 
some of the records in the cachedState are AVAILABLE,
         // even though the maxInFlightMessages limit is exceeded, 
replicaManager.readFromLog should be called
@@ -1738,9 +1729,7 @@ public class SharePartitionManagerTest {
         TopicIdPartition tp1 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo1", 0));
         TopicIdPartition tp2 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo2", 0));
 
-        Map<TopicIdPartition, Integer> partitionMaxBytes = new HashMap<>();
-        partitionMaxBytes.put(tp1, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp2, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp1, tp2);
 
         SharePartition sp1 = mock(SharePartition.class);
         SharePartition sp2 = mock(SharePartition.class);
@@ -1843,9 +1832,7 @@ public class SharePartitionManagerTest {
         TopicIdPartition tp2 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo2", 0));
         TopicIdPartition tp3 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo3", 0));
 
-        Map<TopicIdPartition, Integer> partitionMaxBytes = new HashMap<>();
-        partitionMaxBytes.put(tp1, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp2, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp1, tp2);
 
         SharePartition sp1 = mock(SharePartition.class);
         SharePartition sp2 = mock(SharePartition.class);
@@ -1950,9 +1937,7 @@ public class SharePartitionManagerTest {
         TopicIdPartition tp2 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo2", 0));
         TopicIdPartition tp3 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo3", 0));
 
-        Map<TopicIdPartition, Integer> partitionMaxBytes = new HashMap<>();
-        partitionMaxBytes.put(tp1, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp2, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp1, tp2);
 
         SharePartition sp1 = mock(SharePartition.class);
         SharePartition sp2 = mock(SharePartition.class);
@@ -2053,9 +2038,7 @@ public class SharePartitionManagerTest {
         TopicIdPartition tp2 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo2", 0));
         TopicIdPartition tp3 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo3", 0));
 
-        Map<TopicIdPartition, Integer> partitionMaxBytes = new HashMap<>();
-        partitionMaxBytes.put(tp1, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp2, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp1, tp2);
 
         SharePartition sp1 = mock(SharePartition.class);
         SharePartition sp2 = mock(SharePartition.class);
@@ -2156,7 +2139,7 @@ public class SharePartitionManagerTest {
         Uuid memberId = Uuid.randomUuid();
         Uuid fooId = Uuid.randomUuid();
         TopicIdPartition tp0 = new TopicIdPartition(fooId, new 
TopicPartition("foo", 0));
-        Map<TopicIdPartition, Integer> partitionMaxBytes = 
Collections.singletonMap(tp0, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp0);
 
         SharePartition sp0 = mock(SharePartition.class);
         Map<SharePartitionKey, SharePartition> partitionCacheMap = new 
HashMap<>();
@@ -2179,7 +2162,7 @@ public class SharePartitionManagerTest {
             .build();
 
         CompletableFuture<Map<TopicIdPartition, 
ShareFetchResponseData.PartitionData>> future =
-            sharePartitionManager.fetchMessages(groupId, memberId.toString(), 
FETCH_PARAMS,
+            sharePartitionManager.fetchMessages(groupId, memberId.toString(), 
FETCH_PARAMS, 0,
                 BATCH_SIZE, partitionMaxBytes);
         // Verify that the fetch request is completed.
         TestUtils.waitForCondition(
@@ -2208,7 +2191,7 @@ public class SharePartitionManagerTest {
         Uuid memberId = Uuid.randomUuid();
         Uuid fooId = Uuid.randomUuid();
         TopicIdPartition tp0 = new TopicIdPartition(fooId, new 
TopicPartition("foo", 0));
-        Map<TopicIdPartition, Integer> partitionMaxBytes = 
Collections.singletonMap(tp0, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp0);
 
         SharePartition sp0 = mock(SharePartition.class);
         Map<SharePartitionKey, SharePartition> partitionCacheMap = new 
HashMap<>();
@@ -2236,15 +2219,15 @@ public class SharePartitionManagerTest {
 
         // Send 3 requests for share fetch for same share partition.
         CompletableFuture<Map<TopicIdPartition, 
ShareFetchResponseData.PartitionData>> future1 =
-            sharePartitionManager.fetchMessages(groupId, memberId.toString(), 
FETCH_PARAMS,
+            sharePartitionManager.fetchMessages(groupId, memberId.toString(), 
FETCH_PARAMS, 0,
                 BATCH_SIZE, partitionMaxBytes);
 
         CompletableFuture<Map<TopicIdPartition, 
ShareFetchResponseData.PartitionData>> future2 =
-            sharePartitionManager.fetchMessages(groupId, memberId.toString(), 
FETCH_PARAMS,
+            sharePartitionManager.fetchMessages(groupId, memberId.toString(), 
FETCH_PARAMS, 0,
                 BATCH_SIZE, partitionMaxBytes);
 
         CompletableFuture<Map<TopicIdPartition, 
ShareFetchResponseData.PartitionData>> future3 =
-            sharePartitionManager.fetchMessages(groupId, memberId.toString(), 
FETCH_PARAMS,
+            sharePartitionManager.fetchMessages(groupId, memberId.toString(), 
FETCH_PARAMS, 0,
                 BATCH_SIZE, partitionMaxBytes);
 
         Mockito.verify(sp0, times(3)).maybeInitialize();
@@ -2281,7 +2264,7 @@ public class SharePartitionManagerTest {
         Uuid memberId = Uuid.randomUuid();
         Uuid fooId = Uuid.randomUuid();
         TopicIdPartition tp0 = new TopicIdPartition(fooId, new 
TopicPartition("foo", 0));
-        Map<TopicIdPartition, Integer> partitionMaxBytes = 
Collections.singletonMap(tp0, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp0);
 
         SharePartition sp0 = mock(SharePartition.class);
         Map<SharePartitionKey, SharePartition> partitionCacheMap = new 
HashMap<>();
@@ -2302,7 +2285,7 @@ public class SharePartitionManagerTest {
         // Return LeaderNotAvailableException to simulate initialization 
failure.
         when(sp0.maybeInitialize()).thenReturn(FutureUtils.failedFuture(new 
LeaderNotAvailableException("Leader not available")));
         CompletableFuture<Map<TopicIdPartition, 
ShareFetchResponseData.PartitionData>> future =
-            sharePartitionManager.fetchMessages(groupId, memberId.toString(), 
FETCH_PARAMS,
+            sharePartitionManager.fetchMessages(groupId, memberId.toString(), 
FETCH_PARAMS, 0,
                 BATCH_SIZE, partitionMaxBytes);
         TestUtils.waitForCondition(
             future::isDone,
@@ -2318,7 +2301,7 @@ public class SharePartitionManagerTest {
 
         // Return IllegalStateException to simulate initialization failure.
         when(sp0.maybeInitialize()).thenReturn(FutureUtils.failedFuture(new 
IllegalStateException("Illegal state")));
-        future = sharePartitionManager.fetchMessages(groupId, 
memberId.toString(), FETCH_PARAMS,
+        future = sharePartitionManager.fetchMessages(groupId, 
memberId.toString(), FETCH_PARAMS, 0,
             BATCH_SIZE, partitionMaxBytes);
         TestUtils.waitForCondition(
             future::isDone,
@@ -2332,7 +2315,7 @@ public class SharePartitionManagerTest {
         partitionCacheMap.put(new SharePartitionKey(groupId, tp0), sp0);
         // Return CoordinatorNotAvailableException to simulate initialization 
failure.
         when(sp0.maybeInitialize()).thenReturn(FutureUtils.failedFuture(new 
CoordinatorNotAvailableException("Coordinator not available")));
-        future = sharePartitionManager.fetchMessages(groupId, 
memberId.toString(), FETCH_PARAMS,
+        future = sharePartitionManager.fetchMessages(groupId, 
memberId.toString(), FETCH_PARAMS, 0,
             BATCH_SIZE, partitionMaxBytes);
         TestUtils.waitForCondition(
             future::isDone,
@@ -2346,7 +2329,7 @@ public class SharePartitionManagerTest {
         partitionCacheMap.put(new SharePartitionKey(groupId, tp0), sp0);
         // Return InvalidRequestException to simulate initialization failure.
         when(sp0.maybeInitialize()).thenReturn(FutureUtils.failedFuture(new 
InvalidRequestException("Invalid request")));
-        future = sharePartitionManager.fetchMessages(groupId, 
memberId.toString(), FETCH_PARAMS,
+        future = sharePartitionManager.fetchMessages(groupId, 
memberId.toString(), FETCH_PARAMS, 0,
             BATCH_SIZE, partitionMaxBytes);
         TestUtils.waitForCondition(
             future::isDone,
@@ -2360,7 +2343,7 @@ public class SharePartitionManagerTest {
         partitionCacheMap.put(new SharePartitionKey(groupId, tp0), sp0);
         // Return FencedStateEpochException to simulate initialization failure.
         when(sp0.maybeInitialize()).thenReturn(FutureUtils.failedFuture(new 
FencedStateEpochException("Fenced state epoch")));
-        future = sharePartitionManager.fetchMessages(groupId, 
memberId.toString(), FETCH_PARAMS,
+        future = sharePartitionManager.fetchMessages(groupId, 
memberId.toString(), FETCH_PARAMS, 0,
             BATCH_SIZE, partitionMaxBytes);
         TestUtils.waitForCondition(
             future::isDone,
@@ -2374,7 +2357,7 @@ public class SharePartitionManagerTest {
         partitionCacheMap.put(new SharePartitionKey(groupId, tp0), sp0);
         // Return NotLeaderOrFollowerException to simulate initialization 
failure.
         when(sp0.maybeInitialize()).thenReturn(FutureUtils.failedFuture(new 
NotLeaderOrFollowerException("Not leader or follower")));
-        future = sharePartitionManager.fetchMessages(groupId, 
memberId.toString(), FETCH_PARAMS,
+        future = sharePartitionManager.fetchMessages(groupId, 
memberId.toString(), FETCH_PARAMS, 0,
             BATCH_SIZE, partitionMaxBytes);
         TestUtils.waitForCondition(
             future::isDone,
@@ -2388,7 +2371,7 @@ public class SharePartitionManagerTest {
         partitionCacheMap.put(new SharePartitionKey(groupId, tp0), sp0);
         // Return RuntimeException to simulate initialization failure.
         when(sp0.maybeInitialize()).thenReturn(FutureUtils.failedFuture(new 
RuntimeException("Runtime exception")));
-        future = sharePartitionManager.fetchMessages(groupId, 
memberId.toString(), FETCH_PARAMS,
+        future = sharePartitionManager.fetchMessages(groupId, 
memberId.toString(), FETCH_PARAMS, 0,
             BATCH_SIZE, partitionMaxBytes);
         TestUtils.waitForCondition(
             future::isDone,
@@ -2406,13 +2389,12 @@ public class SharePartitionManagerTest {
         );
     }
 
-
     @Test
     @SuppressWarnings("unchecked")
     public void testShareFetchProcessingExceptions() throws Exception {
         String groupId = "grp";
         TopicIdPartition tp0 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo", 0));
-        Map<TopicIdPartition, Integer> partitionMaxBytes = 
Collections.singletonMap(tp0, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp0);
 
         Map<SharePartitionKey, SharePartition> partitionCacheMap = 
(Map<SharePartitionKey, SharePartition>) mock(Map.class);
         // Throw the exception for first fetch request. Return share partition 
for next.
@@ -2425,7 +2407,7 @@ public class SharePartitionManagerTest {
             .build();
 
         CompletableFuture<Map<TopicIdPartition, 
ShareFetchResponseData.PartitionData>> future =
-            sharePartitionManager.fetchMessages(groupId, 
Uuid.randomUuid().toString(), FETCH_PARAMS,
+            sharePartitionManager.fetchMessages(groupId, 
Uuid.randomUuid().toString(), FETCH_PARAMS, 0,
                 BATCH_SIZE, partitionMaxBytes);
         TestUtils.waitForCondition(
             future::isDone,
@@ -2444,7 +2426,7 @@ public class SharePartitionManagerTest {
     public void testSharePartitionInitializationFailure() throws Exception {
         String groupId = "grp";
         TopicIdPartition tp0 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo", 0));
-        Map<TopicIdPartition, Integer> partitionMaxBytes = 
Collections.singletonMap(tp0, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp0);
 
         // Send map to check no share partition is created.
         Map<SharePartitionKey, SharePartition> partitionCacheMap = new 
HashMap<>();
@@ -2466,7 +2448,7 @@ public class SharePartitionManagerTest {
 
         // Validate when exception is thrown.
         CompletableFuture<Map<TopicIdPartition, 
ShareFetchResponseData.PartitionData>> future =
-            sharePartitionManager.fetchMessages(groupId, 
Uuid.randomUuid().toString(), FETCH_PARAMS,
+            sharePartitionManager.fetchMessages(groupId, 
Uuid.randomUuid().toString(), FETCH_PARAMS, 0,
                 BATCH_SIZE, partitionMaxBytes);
         TestUtils.waitForCondition(
             future::isDone,
@@ -2476,7 +2458,7 @@ public class SharePartitionManagerTest {
         assertTrue(partitionCacheMap.isEmpty());
 
         // Validate when partition is not leader.
-        future = sharePartitionManager.fetchMessages(groupId, 
Uuid.randomUuid().toString(), FETCH_PARAMS,
+        future = sharePartitionManager.fetchMessages(groupId, 
Uuid.randomUuid().toString(), FETCH_PARAMS, 0,
             BATCH_SIZE, partitionMaxBytes);
         TestUtils.waitForCondition(
             future::isDone,
@@ -2502,10 +2484,7 @@ public class SharePartitionManagerTest {
         TopicIdPartition tp1 = new TopicIdPartition(memberId1, new 
TopicPartition("foo", 1));
         // For tp2, share partition initialization will fail.
         TopicIdPartition tp2 = new TopicIdPartition(memberId1, new 
TopicPartition("foo", 2));
-        Map<TopicIdPartition, Integer> partitionMaxBytes = Map.of(
-            tp0, PARTITION_MAX_BYTES,
-            tp1, PARTITION_MAX_BYTES,
-            tp2, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp0, tp1, tp2);
 
         // Mark partition0 as not the leader.
         Partition partition0 = mock(Partition.class);
@@ -2546,7 +2525,7 @@ public class SharePartitionManagerTest {
 
         // Validate when exception is thrown.
         CompletableFuture<Map<TopicIdPartition, 
ShareFetchResponseData.PartitionData>> future =
-            sharePartitionManager.fetchMessages(groupId, 
Uuid.randomUuid().toString(), FETCH_PARAMS,
+            sharePartitionManager.fetchMessages(groupId, 
Uuid.randomUuid().toString(), FETCH_PARAMS, 0,
                 BATCH_SIZE, partitionMaxBytes);
         assertTrue(future.isDone());
         assertFalse(future.isCompletedExceptionally());
@@ -2578,7 +2557,7 @@ public class SharePartitionManagerTest {
         String groupId = "grp";
         Uuid memberId = Uuid.randomUuid();
         TopicIdPartition tp0 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo", 0));
-        Map<TopicIdPartition, Integer> partitionMaxBytes = 
Collections.singletonMap(tp0, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp0);
 
         SharePartition sp0 = mock(SharePartition.class);
         when(sp0.maybeAcquireFetchLock()).thenReturn(true);
@@ -2602,7 +2581,7 @@ public class SharePartitionManagerTest {
             .build();
 
         CompletableFuture<Map<TopicIdPartition, 
ShareFetchResponseData.PartitionData>> future =
-            sharePartitionManager.fetchMessages(groupId, memberId.toString(), 
FETCH_PARAMS,
+            sharePartitionManager.fetchMessages(groupId, memberId.toString(), 
FETCH_PARAMS, 0,
                 BATCH_SIZE, partitionMaxBytes);
         validateShareFetchFutureException(future, tp0, 
Errors.UNKNOWN_SERVER_ERROR, "Exception");
         // Verify that the share partition is still in the cache on exception.
@@ -2611,7 +2590,7 @@ public class SharePartitionManagerTest {
         // Throw NotLeaderOrFollowerException from replica manager fetch which 
should evict instance from the cache.
         doThrow(new NotLeaderOrFollowerException("Leader 
exception")).when(mockReplicaManager).readFromLog(any(), any(), 
any(ReplicaQuota.class), anyBoolean());
 
-        future = sharePartitionManager.fetchMessages(groupId, 
memberId.toString(), FETCH_PARAMS,
+        future = sharePartitionManager.fetchMessages(groupId, 
memberId.toString(), FETCH_PARAMS, 0,
             BATCH_SIZE, partitionMaxBytes);
         validateShareFetchFutureException(future, tp0, 
Errors.NOT_LEADER_OR_FOLLOWER, "Leader exception");
         assertTrue(partitionCacheMap.isEmpty());
@@ -2630,9 +2609,7 @@ public class SharePartitionManagerTest {
 
         TopicIdPartition tp0 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo", 0));
         TopicIdPartition tp1 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("bar", 0));
-        Map<TopicIdPartition, Integer> partitionMaxBytes = new HashMap<>();
-        partitionMaxBytes.put(tp0, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp1, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp0, tp1);
 
         SharePartition sp0 = mock(SharePartition.class);
         when(sp0.maybeAcquireFetchLock()).thenReturn(true);
@@ -2665,7 +2642,7 @@ public class SharePartitionManagerTest {
             .build();
 
         CompletableFuture<Map<TopicIdPartition, 
ShareFetchResponseData.PartitionData>> future =
-            sharePartitionManager.fetchMessages(groupId, memberId.toString(), 
FETCH_PARAMS,
+            sharePartitionManager.fetchMessages(groupId, memberId.toString(), 
FETCH_PARAMS, 0,
                 BATCH_SIZE, partitionMaxBytes);
         validateShareFetchFutureException(future, tp0, 
Errors.FENCED_STATE_EPOCH, "Fenced exception");
         // Verify that tp1 is still in the cache on exception.
@@ -2680,7 +2657,7 @@ public class SharePartitionManagerTest {
         // Throw FencedStateEpochException from replica manager fetch which 
should evict instance from the cache.
         doThrow(new FencedStateEpochException("Fenced exception 
again")).when(mockReplicaManager).readFromLog(any(), any(), 
any(ReplicaQuota.class), anyBoolean());
 
-        future = sharePartitionManager.fetchMessages(groupId, 
memberId.toString(), FETCH_PARAMS,
+        future = sharePartitionManager.fetchMessages(groupId, 
memberId.toString(), FETCH_PARAMS, 0,
             BATCH_SIZE, partitionMaxBytes);
         validateShareFetchFutureException(future, List.of(tp0, tp1), 
Errors.FENCED_STATE_EPOCH, "Fenced exception again");
         assertTrue(partitionCacheMap.isEmpty());
@@ -2701,9 +2678,7 @@ public class SharePartitionManagerTest {
 
         TopicIdPartition tp0 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo", 0));
         TopicIdPartition tp1 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("bar", 0));
-        Map<TopicIdPartition, Integer> partitionMaxBytes = new HashMap<>();
-        partitionMaxBytes.put(tp0, PARTITION_MAX_BYTES);
-        partitionMaxBytes.put(tp1, PARTITION_MAX_BYTES);
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES, tp0, tp1);
 
         ReplicaManager mockReplicaManager = mock(ReplicaManager.class);
         Partition partition = mockPartition();
@@ -2716,7 +2691,7 @@ public class SharePartitionManagerTest {
             .build();
 
         CompletableFuture<Map<TopicIdPartition, PartitionData>> future = 
sharePartitionManager.fetchMessages(
-            groupId, memberId.toString(), FETCH_PARAMS, BATCH_SIZE, 
partitionMaxBytes);
+            groupId, memberId.toString(), FETCH_PARAMS, 0, BATCH_SIZE, 
partitionMaxBytes);
         assertTrue(future.isDone());
         // Validate that the listener is registered.
         verify(mockReplicaManager, times(2)).maybeAddListener(any(), any());
@@ -2762,6 +2737,63 @@ public class SharePartitionManagerTest {
         testSharePartitionListener(sharePartitionKey, partitionCacheMap, 
mockReplicaManager, partitionListener::onBecomingFollower);
     }
 
+    @Test
+    public void testFetchMessagesRotatePartitions() {
+        String groupId = "grp";
+        Uuid memberId1 = Uuid.randomUuid();
+        TopicIdPartition tp0 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo", 0));
+        TopicIdPartition tp1 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo", 1));
+        TopicIdPartition tp2 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("bar", 0));
+        TopicIdPartition tp3 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("bar", 1));
+        TopicIdPartition tp4 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo", 2));
+        TopicIdPartition tp5 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("bar", 2));
+        TopicIdPartition tp6 = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("foo", 3));
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = 
orderedMap(PARTITION_MAX_BYTES,
+            tp0, tp1, tp2, tp3, tp4, tp5, tp6);
+
+        SharePartitionManager sharePartitionManager = 
Mockito.spy(SharePartitionManagerBuilder.builder().withBrokerTopicStats(brokerTopicStats).build());
+        // Capture the arguments passed to processShareFetch.
+        ArgumentCaptor<ShareFetch> captor = 
ArgumentCaptor.forClass(ShareFetch.class);
+
+        sharePartitionManager.fetchMessages(groupId, memberId1.toString(), 
FETCH_PARAMS, 0, BATCH_SIZE,
+            partitionMaxBytes);
+        verify(sharePartitionManager, 
times(1)).processShareFetch(captor.capture());
+        // Verify the partitions rotation, no rotation.
+        ShareFetch resultShareFetch = captor.getValue();
+        validateRotatedMapEquals(resultShareFetch.partitionMaxBytes(), 
partitionMaxBytes, 0);
+
+        // Single rotation.
+        sharePartitionManager.fetchMessages(groupId, memberId1.toString(), 
FETCH_PARAMS, 1, BATCH_SIZE,
+            partitionMaxBytes);
+        verify(sharePartitionManager, 
times(2)).processShareFetch(captor.capture());
+        // Verify the partitions rotation, rotate by 1.
+        resultShareFetch = captor.getValue();
+        validateRotatedMapEquals(partitionMaxBytes, 
resultShareFetch.partitionMaxBytes(), 1);
+
+        // Rotation by 3, less than the number of partitions.
+        sharePartitionManager.fetchMessages(groupId, memberId1.toString(), 
FETCH_PARAMS, 3, BATCH_SIZE,
+            partitionMaxBytes);
+        verify(sharePartitionManager, 
times(3)).processShareFetch(captor.capture());
+        // Verify the partitions rotation, rotate by 3.
+        resultShareFetch = captor.getValue();
+        validateRotatedMapEquals(partitionMaxBytes, 
resultShareFetch.partitionMaxBytes(), 3);
+
+        // Rotation by 12, more than the number of partitions.
+        sharePartitionManager.fetchMessages(groupId, memberId1.toString(), 
FETCH_PARAMS, 12, BATCH_SIZE,
+            partitionMaxBytes);
+        verify(sharePartitionManager, 
times(4)).processShareFetch(captor.capture());
+        // Verify the partitions rotation, rotate by 5 (12 % 7).
+        resultShareFetch = captor.getValue();
+        validateRotatedMapEquals(partitionMaxBytes, 
resultShareFetch.partitionMaxBytes(), 5);
+        // Rotation by Integer.MAX_VALUE, boundary test.
+        sharePartitionManager.fetchMessages(groupId, memberId1.toString(), 
FETCH_PARAMS, Integer.MAX_VALUE, BATCH_SIZE,
+            partitionMaxBytes);
+        verify(sharePartitionManager, 
times(5)).processShareFetch(captor.capture());
+        // Verify the partitions rotation, rotate by 1 (2147483647 % 7).
+        resultShareFetch = captor.getValue();
+        validateRotatedMapEquals(partitionMaxBytes, 
resultShareFetch.partitionMaxBytes(), 1);
+    }
+
     private void testSharePartitionListener(
         SharePartitionKey sharePartitionKey,
         Map<SharePartitionKey, SharePartition> partitionCacheMap,
diff --git a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala 
b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
index bcf9ec587d5..11ebee134b8 100644
--- a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
+++ b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
@@ -3897,7 +3897,7 @@ class KafkaApisTest extends Logging {
 
     val records = memoryRecords(10, 0)
 
-    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), 
any())).thenReturn(
+    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), 
anyInt(), any())).thenReturn(
       CompletableFuture.completedFuture(Map[TopicIdPartition, 
ShareFetchResponseData.PartitionData](
         new TopicIdPartition(topicId, new TopicPartition(topicName, 
partitionIndex)) ->
           new ShareFetchResponseData.PartitionData()
@@ -3969,7 +3969,7 @@ class KafkaApisTest extends Logging {
 
     val records = memoryRecords(10, 0)
 
-    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), 
any())).thenReturn(
+    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), 
anyInt(), any())).thenReturn(
       CompletableFuture.completedFuture(Map[TopicIdPartition, 
ShareFetchResponseData.PartitionData](
         new TopicIdPartition(topicId, new TopicPartition(topicName, 
partitionIndex)) ->
           new ShareFetchResponseData.PartitionData()
@@ -4074,7 +4074,7 @@ class KafkaApisTest extends Logging {
 
     val records = memoryRecords(10, 0)
 
-    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), 
any())).thenReturn(
+    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), 
anyInt(), any())).thenReturn(
       CompletableFuture.completedFuture(Map[TopicIdPartition, 
ShareFetchResponseData.PartitionData](
         new TopicIdPartition(topicId, new TopicPartition(topicName, 
partitionIndex)) ->
           new ShareFetchResponseData.PartitionData()
@@ -4178,7 +4178,7 @@ class KafkaApisTest extends Logging {
     addTopicToMetadataCache(topicName, 1, topicId = topicId)
     val memberId: Uuid = Uuid.ZERO_UUID
 
-    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), 
any())).thenReturn(
+    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), 
anyInt(), any())).thenReturn(
       FutureUtils.failedFuture[util.Map[TopicIdPartition, 
ShareFetchResponseData.PartitionData]](Errors.UNKNOWN_SERVER_ERROR.exception())
     )
 
@@ -4230,7 +4230,7 @@ class KafkaApisTest extends Logging {
 
     val records = memoryRecords(10, 0)
 
-    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), 
any())).thenReturn(
+    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), 
anyInt(), any())).thenReturn(
       CompletableFuture.completedFuture(Map[TopicIdPartition, 
ShareFetchResponseData.PartitionData](
         new TopicIdPartition(topicId, new TopicPartition(topicName, 
partitionIndex)) ->
           new ShareFetchResponseData.PartitionData()
@@ -4305,7 +4305,7 @@ class KafkaApisTest extends Logging {
 
     val groupId = "group"
 
-    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), 
any())).thenReturn(
+    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), 
anyInt(), any())).thenReturn(
       FutureUtils.failedFuture[util.Map[TopicIdPartition, 
ShareFetchResponseData.PartitionData]](Errors.UNKNOWN_SERVER_ERROR.exception())
     )
 
@@ -4369,7 +4369,7 @@ class KafkaApisTest extends Logging {
 
     val records = MemoryRecords.EMPTY
 
-    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), 
any())).thenReturn(
+    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), 
anyInt(), any())).thenReturn(
       CompletableFuture.completedFuture(Map[TopicIdPartition, 
ShareFetchResponseData.PartitionData](
         new TopicIdPartition(topicId, new TopicPartition(topicName, 
partitionIndex)) ->
           new ShareFetchResponseData.PartitionData()
@@ -4434,7 +4434,7 @@ class KafkaApisTest extends Logging {
     val groupId = "group"
     val records = memoryRecords(10, 0)
 
-    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), 
any())).thenReturn(
+    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), 
anyInt(), any())).thenReturn(
       CompletableFuture.completedFuture(Map[TopicIdPartition, 
ShareFetchResponseData.PartitionData](
         new TopicIdPartition(topicId, new TopicPartition(topicName, 
partitionIndex)) ->
           new ShareFetchResponseData.PartitionData()
@@ -4527,7 +4527,7 @@ class KafkaApisTest extends Logging {
     val groupId = "group"
     val records = memoryRecords(10, 0)
 
-    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), 
any())).thenReturn(
+    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), 
anyInt(), any())).thenReturn(
       CompletableFuture.completedFuture(Map[TopicIdPartition, 
ShareFetchResponseData.PartitionData](
         new TopicIdPartition(topicId, new TopicPartition(topicName, 
partitionIndex)) ->
           new ShareFetchResponseData.PartitionData()
@@ -4620,7 +4620,7 @@ class KafkaApisTest extends Logging {
     val records2 = memoryRecords(10, 10)
     val records3 = memoryRecords(10, 20)
 
-    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), 
any())).thenReturn(
+    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), 
anyInt(), any())).thenReturn(
       CompletableFuture.completedFuture(Map[TopicIdPartition, 
ShareFetchResponseData.PartitionData](
         new TopicIdPartition(topicId, new TopicPartition(topicName, 
partitionIndex)) ->
           new ShareFetchResponseData.PartitionData()
@@ -4848,7 +4848,7 @@ class KafkaApisTest extends Logging {
 
     val groupId = "group"
 
-    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), 
any())).thenReturn(
+    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), 
anyInt(), any())).thenReturn(
       CompletableFuture.completedFuture(Map[TopicIdPartition, 
ShareFetchResponseData.PartitionData](
         new TopicIdPartition(topicId1, new TopicPartition(topicName1, 0)) ->
           new ShareFetchResponseData.PartitionData()
@@ -5323,7 +5323,7 @@ class KafkaApisTest extends Logging {
     val tp2 = new TopicIdPartition(topicId2, new TopicPartition(topicName2, 0))
     val tp3 = new TopicIdPartition(topicId2, new TopicPartition(topicName2, 1))
 
-    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), 
any())).thenReturn(
+    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), 
anyInt(), any())).thenReturn(
       CompletableFuture.completedFuture(Map[TopicIdPartition, 
ShareFetchResponseData.PartitionData](
         tp1 ->
           new ShareFetchResponseData.PartitionData()
@@ -5419,6 +5419,7 @@ class KafkaApisTest extends Logging {
     val fetchResult: Map[TopicIdPartition, 
ShareFetchResponseData.PartitionData] =
       kafkaApis.handleFetchFromShareFetchRequest(
         request,
+        0,
         erroneousAndValidPartitionData,
         sharePartitionManager,
         authorizedTopics
@@ -5487,7 +5488,7 @@ class KafkaApisTest extends Logging {
     val tp2 = new TopicIdPartition(topicId1, new TopicPartition(topicName1, 1))
     val tp3 = new TopicIdPartition(topicId2, new TopicPartition(topicName2, 0))
 
-    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), 
any())).thenReturn(
+    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), 
anyInt(), any())).thenReturn(
       CompletableFuture.completedFuture(Map[TopicIdPartition, 
ShareFetchResponseData.PartitionData](
         tp1 ->
           new ShareFetchResponseData.PartitionData()
@@ -5565,6 +5566,7 @@ class KafkaApisTest extends Logging {
     val fetchResult: Map[TopicIdPartition, 
ShareFetchResponseData.PartitionData] =
       kafkaApis.handleFetchFromShareFetchRequest(
         request,
+        0,
         erroneousAndValidPartitionData,
         sharePartitionManager,
         authorizedTopics
@@ -5627,7 +5629,7 @@ class KafkaApisTest extends Logging {
     val tp2 = new TopicIdPartition(topicId2, new TopicPartition(topicName2, 0))
     val tp3 = new TopicIdPartition(topicId2, new TopicPartition(topicName2, 1))
 
-    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), 
any())).thenReturn(
+    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), 
anyInt(), any())).thenReturn(
       CompletableFuture.completedFuture(Map[TopicIdPartition, 
ShareFetchResponseData.PartitionData](
         tp1 ->
           new ShareFetchResponseData.PartitionData()
@@ -5708,6 +5710,7 @@ class KafkaApisTest extends Logging {
     val fetchResult: Map[TopicIdPartition, 
ShareFetchResponseData.PartitionData] =
       kafkaApis.handleFetchFromShareFetchRequest(
         request,
+        0,
         erroneousAndValidPartitionData,
         sharePartitionManager,
         authorizedTopics
@@ -5781,7 +5784,7 @@ class KafkaApisTest extends Logging {
     val tp3 = new TopicIdPartition(topicId2, new TopicPartition(topicName2, 1))
     val tp4 = new TopicIdPartition(topicId3, new TopicPartition(topicName3, 0))
 
-    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), 
any())).thenReturn(
+    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), 
anyInt(), any())).thenReturn(
       CompletableFuture.completedFuture(Map[TopicIdPartition, 
ShareFetchResponseData.PartitionData](
         tp2 ->
           new ShareFetchResponseData.PartitionData()
@@ -5877,6 +5880,7 @@ class KafkaApisTest extends Logging {
     val fetchResult: Map[TopicIdPartition, 
ShareFetchResponseData.PartitionData] =
       kafkaApis.handleFetchFromShareFetchRequest(
         request,
+        0,
         erroneousAndValidPartitionData,
         sharePartitionManager,
         authorizedTopics
@@ -5973,7 +5977,7 @@ class KafkaApisTest extends Logging {
 
     val groupId = "group"
 
-    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), 
any())).thenReturn(
+    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), 
anyInt(), any())).thenReturn(
       CompletableFuture.completedFuture(Map[TopicIdPartition, 
ShareFetchResponseData.PartitionData](
         new TopicIdPartition(topicId, new TopicPartition(topicName, 
partitionIndex)) ->
           new ShareFetchResponseData.PartitionData()
diff --git 
a/server/src/main/java/org/apache/kafka/server/share/fetch/PartitionRotateStrategy.java
 
b/server/src/main/java/org/apache/kafka/server/share/fetch/PartitionRotateStrategy.java
new file mode 100644
index 00000000000..42fdaa58cce
--- /dev/null
+++ 
b/server/src/main/java/org/apache/kafka/server/share/fetch/PartitionRotateStrategy.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.server.share.fetch;
+
+import org.apache.kafka.common.TopicIdPartition;
+
+import java.util.LinkedHashMap;
+import java.util.Locale;
+import java.util.Map;
+
+/**
+ * The PartitionRotateStrategy is used to rotate the partitions based on the 
respective strategy.
+ * The share-partitions are rotated to ensure no share-partitions are starved 
from records being fetched.
+ */
+public interface PartitionRotateStrategy {
+
+    /**
+     * The strategy type to rotate the partitions.
+     */
+    enum StrategyType {
+        ROUND_ROBIN;
+
+        @Override
+        public String toString() {
+            return super.toString().toLowerCase(Locale.ROOT);
+        }
+    }
+
+    /**
+     * Rotate the partitions based on the strategy.
+     *
+     * @param topicIdPartitions the topicIdPartitions to rotate
+     * @param metadata the metadata to rotate
+     *
+     * @return the rotated topicIdPartitions
+     */
+    LinkedHashMap<TopicIdPartition, Integer> 
rotate(LinkedHashMap<TopicIdPartition, Integer> topicIdPartitions, 
PartitionRotateMetadata metadata);
+
+    static PartitionRotateStrategy type(StrategyType type) {
+        return switch (type) {
+            case ROUND_ROBIN -> PartitionRotateStrategy::rotateRoundRobin;
+        };
+    }
+
+    /**
+     * Rotate the partitions based on the round-robin strategy.
+     *
+     * @param topicIdPartitions the topicIdPartitions to rotate
+     * @param metadata the metadata to rotate
+     *
+     * @return the rotated topicIdPartitions
+     */
+    static LinkedHashMap<TopicIdPartition, Integer> rotateRoundRobin(
+        LinkedHashMap<TopicIdPartition, Integer> topicIdPartitions,
+        PartitionRotateMetadata metadata
+    ) {
+        if (topicIdPartitions.isEmpty() || topicIdPartitions.size() == 1 || 
metadata.sessionEpoch < 1) {
+            // No need to rotate the partitions if there are no partitions, 
only one partition or the
+            // session epoch is initial or final.
+            return topicIdPartitions;
+        }
+
+        int rotateAt = metadata.sessionEpoch % topicIdPartitions.size();
+        if (rotateAt == 0) {
+            // No need to rotate the partitions if the rotateAt is 0.
+            return topicIdPartitions;
+        }
+
+        // TODO: Once the partition max bytes is removed then the partition 
will be a linked list and rotation
+        //  will be a simple operation. Else consider using 
ImplicitLinkedHashCollection.
+        LinkedHashMap<TopicIdPartition, Integer> suffixPartitions = new 
LinkedHashMap<>(rotateAt);
+        LinkedHashMap<TopicIdPartition, Integer> rotatedPartitions = new 
LinkedHashMap<>(topicIdPartitions.size());
+        int i = 0;
+        for (Map.Entry<TopicIdPartition, Integer> entry : 
topicIdPartitions.entrySet()) {
+            if (i < rotateAt) {
+                suffixPartitions.put(entry.getKey(), entry.getValue());
+            } else {
+                rotatedPartitions.put(entry.getKey(), entry.getValue());
+            }
+            i++;
+        }
+        rotatedPartitions.putAll(suffixPartitions);
+        return rotatedPartitions;
+    }
+
+    /**
+     * The partition rotate metadata which can be used to store the metadata 
for the partition rotation.
+     *
+     * @param sessionEpoch the share session epoch.
+     */
+    record PartitionRotateMetadata(int sessionEpoch) { }
+}
diff --git 
a/server/src/main/java/org/apache/kafka/server/share/fetch/ShareFetch.java 
b/server/src/main/java/org/apache/kafka/server/share/fetch/ShareFetch.java
index 9d93f3fce8d..521e0807268 100644
--- a/server/src/main/java/org/apache/kafka/server/share/fetch/ShareFetch.java
+++ b/server/src/main/java/org/apache/kafka/server/share/fetch/ShareFetch.java
@@ -26,6 +26,7 @@ import org.apache.kafka.storage.log.metrics.BrokerTopicStats;
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.LinkedHashMap;
 import java.util.LinkedHashSet;
 import java.util.Map;
 import java.util.Set;
@@ -56,7 +57,7 @@ public class ShareFetch {
     /**
      * The maximum number of bytes that can be fetched for each partition.
      */
-    private final Map<TopicIdPartition, Integer> partitionMaxBytes;
+    private final LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes;
     /**
      * The batch size of the fetch request.
      */
@@ -80,7 +81,7 @@ public class ShareFetch {
         String groupId,
         String memberId,
         CompletableFuture<Map<TopicIdPartition, PartitionData>> future,
-        Map<TopicIdPartition, Integer> partitionMaxBytes,
+        LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes,
         int batchSize,
         int maxFetchRecords,
         BrokerTopicStats brokerTopicStats
@@ -103,7 +104,7 @@ public class ShareFetch {
         return memberId;
     }
 
-    public Map<TopicIdPartition, Integer> partitionMaxBytes() {
+    public LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes() {
         return partitionMaxBytes;
     }
 
diff --git 
a/server/src/test/java/org/apache/kafka/server/share/fetch/PartitionRotateStrategyTest.java
 
b/server/src/test/java/org/apache/kafka/server/share/fetch/PartitionRotateStrategyTest.java
new file mode 100644
index 00000000000..f05490c8747
--- /dev/null
+++ 
b/server/src/test/java/org/apache/kafka/server/share/fetch/PartitionRotateStrategyTest.java
@@ -0,0 +1,99 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.kafka.server.share.fetch;

import org.apache.kafka.common.TopicIdPartition;
import org.apache.kafka.common.Uuid;
import org.apache.kafka.common.requests.ShareRequestMetadata;
import org.apache.kafka.server.share.fetch.PartitionRotateStrategy.PartitionRotateMetadata;
import org.apache.kafka.server.share.fetch.PartitionRotateStrategy.StrategyType;

import org.junit.jupiter.api.Test;

import java.util.LinkedHashMap;

import static org.apache.kafka.server.share.fetch.ShareFetchTestUtils.validateRotatedMapEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

public class PartitionRotateStrategyTest {

    @Test
    public void testRoundRobinStrategy() {
        PartitionRotateStrategy strategy = PartitionRotateStrategy.type(StrategyType.ROUND_ROBIN);
        LinkedHashMap<TopicIdPartition, Integer> orderedPartitions = createPartitions(3);

        // Epoch 1 rotates by a single position.
        LinkedHashMap<TopicIdPartition, Integer> rotated = strategy.rotate(orderedPartitions, new PartitionRotateMetadata(1));
        assertEquals(3, rotated.size());
        validateRotatedMapEquals(orderedPartitions, rotated, 1);

        // An epoch larger than the partition count wraps around (5 % 3 == 2).
        rotated = strategy.rotate(orderedPartitions, new PartitionRotateMetadata(5));
        assertEquals(3, rotated.size());
        validateRotatedMapEquals(orderedPartitions, rotated, 2);

        // The largest possible epoch still rotates correctly (Integer.MAX_VALUE % 3 == 1).
        rotated = strategy.rotate(orderedPartitions, new PartitionRotateMetadata(Integer.MAX_VALUE));
        assertEquals(3, rotated.size());
        validateRotatedMapEquals(orderedPartitions, rotated, 1);

        // An epoch equal to the partition count leaves the order unchanged.
        rotated = strategy.rotate(orderedPartitions, new PartitionRotateMetadata(3));
        assertEquals(3, rotated.size());
        validateRotatedMapEquals(orderedPartitions, rotated, 0);
    }

    @Test
    public void testRoundRobinStrategyWithSpecialSessionEpochs() {
        PartitionRotateStrategy strategy = PartitionRotateStrategy.type(StrategyType.ROUND_ROBIN);
        LinkedHashMap<TopicIdPartition, Integer> orderedPartitions = createPartitions(3);

        // The initial epoch performs no rotation.
        LinkedHashMap<TopicIdPartition, Integer> rotated = strategy.rotate(
            orderedPartitions,
            new PartitionRotateMetadata(ShareRequestMetadata.INITIAL_EPOCH));
        assertEquals(3, rotated.size());
        validateRotatedMapEquals(orderedPartitions, rotated, 0);

        // The final epoch performs no rotation either.
        rotated = strategy.rotate(
            orderedPartitions,
            new PartitionRotateMetadata(ShareRequestMetadata.FINAL_EPOCH));
        assertEquals(3, rotated.size());
        validateRotatedMapEquals(orderedPartitions, rotated, 0);
    }

    @Test
    public void testRoundRobinStrategyWithEmptyPartitions() {
        PartitionRotateStrategy strategy = PartitionRotateStrategy.type(StrategyType.ROUND_ROBIN);
        // Rotating an empty map must yield an empty result regardless of the epoch.
        LinkedHashMap<TopicIdPartition, Integer> rotated = strategy.rotate(new LinkedHashMap<>(), new PartitionRotateMetadata(5));
        assertTrue(rotated.isEmpty());
    }

    /**
     * Create an ordered map of TopicIdPartition to partition max bytes.
     * @param size The number of topic-partitions to create.
     * @return The ordered map of TopicIdPartition to partition max bytes.
     */
    private LinkedHashMap<TopicIdPartition, Integer> createPartitions(int size) {
        LinkedHashMap<TopicIdPartition, Integer> orderedPartitions = new LinkedHashMap<>();
        for (int idx = 0; idx < size; idx++) {
            TopicIdPartition tp = new TopicIdPartition(Uuid.randomUuid(), idx, "foo" + idx);
            orderedPartitions.put(tp, 1 /* partition max bytes*/);
        }
        return orderedPartitions;
    }
}
diff --git 
a/server/src/test/java/org/apache/kafka/server/share/fetch/ShareFetchTest.java 
b/server/src/test/java/org/apache/kafka/server/share/fetch/ShareFetchTest.java
index f58c3588a55..faa01d14938 100644
--- 
a/server/src/test/java/org/apache/kafka/server/share/fetch/ShareFetchTest.java
+++ 
b/server/src/test/java/org/apache/kafka/server/share/fetch/ShareFetchTest.java
@@ -32,6 +32,7 @@ import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 
+import static 
org.apache.kafka.server.share.fetch.ShareFetchTestUtils.orderedMap;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -59,7 +60,7 @@ public class ShareFetchTest {
     public void testErrorInAllPartitions() {
         TopicIdPartition topicIdPartition = new 
TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0));
         ShareFetch shareFetch = new ShareFetch(mock(FetchParams.class), 
GROUP_ID, MEMBER_ID, new CompletableFuture<>(),
-            Map.of(topicIdPartition, 10), BATCH_SIZE, 100, brokerTopicStats);
+            orderedMap(10, topicIdPartition), BATCH_SIZE, 100, 
brokerTopicStats);
         assertFalse(shareFetch.errorInAllPartitions());
 
         shareFetch.addErroneous(topicIdPartition, new RuntimeException());
@@ -71,8 +72,7 @@ public class ShareFetchTest {
         TopicIdPartition topicIdPartition0 = new 
TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0));
         TopicIdPartition topicIdPartition1 = new 
TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 1));
         ShareFetch shareFetch = new ShareFetch(mock(FetchParams.class), 
GROUP_ID, MEMBER_ID, new CompletableFuture<>(),
-            Map.of(topicIdPartition0, 10, topicIdPartition1, 10), BATCH_SIZE, 
100,
-            brokerTopicStats);
+            orderedMap(10, topicIdPartition0, topicIdPartition1), BATCH_SIZE, 
100, brokerTopicStats);
         assertFalse(shareFetch.errorInAllPartitions());
 
         shareFetch.addErroneous(topicIdPartition0, new RuntimeException());
@@ -87,8 +87,7 @@ public class ShareFetchTest {
         TopicIdPartition topicIdPartition0 = new 
TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0));
         TopicIdPartition topicIdPartition1 = new 
TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 1));
         ShareFetch shareFetch = new ShareFetch(mock(FetchParams.class), 
GROUP_ID, MEMBER_ID, new CompletableFuture<>(),
-            Map.of(topicIdPartition0, 10, topicIdPartition1, 10), BATCH_SIZE, 
100,
-            brokerTopicStats);
+            orderedMap(10, topicIdPartition0, topicIdPartition1), BATCH_SIZE, 
100, brokerTopicStats);
         Set<TopicIdPartition> result = 
shareFetch.filterErroneousTopicPartitions(Set.of(topicIdPartition0, 
topicIdPartition1));
         // No erroneous partitions, hence all partitions should be returned.
         assertEquals(2, result.size());
@@ -114,7 +113,7 @@ public class ShareFetchTest {
 
         CompletableFuture<Map<TopicIdPartition, PartitionData>> future = new 
CompletableFuture<>();
         ShareFetch shareFetch = new ShareFetch(mock(FetchParams.class), 
GROUP_ID, MEMBER_ID, future,
-            Map.of(topicIdPartition0, 10, topicIdPartition1, 10), BATCH_SIZE, 
100, brokerTopicStats);
+            orderedMap(10, topicIdPartition0, topicIdPartition1), BATCH_SIZE, 
100, brokerTopicStats);
 
         // Add both erroneous partition and complete request.
         shareFetch.addErroneous(topicIdPartition0, new RuntimeException());
@@ -135,7 +134,7 @@ public class ShareFetchTest {
 
         CompletableFuture<Map<TopicIdPartition, PartitionData>> future = new 
CompletableFuture<>();
         ShareFetch shareFetch = new ShareFetch(mock(FetchParams.class), 
GROUP_ID, MEMBER_ID, future,
-            Map.of(topicIdPartition0, 10, topicIdPartition1, 10), BATCH_SIZE, 
100, brokerTopicStats);
+            orderedMap(10, topicIdPartition0, topicIdPartition1), BATCH_SIZE, 
100, brokerTopicStats);
 
         // Add an erroneous partition and complete request.
         shareFetch.addErroneous(topicIdPartition0, new RuntimeException());
@@ -155,7 +154,7 @@ public class ShareFetchTest {
 
         CompletableFuture<Map<TopicIdPartition, PartitionData>> future = new 
CompletableFuture<>();
         ShareFetch shareFetch = new ShareFetch(mock(FetchParams.class), 
GROUP_ID, MEMBER_ID, future,
-            Map.of(topicIdPartition0, 10, topicIdPartition1, 10), BATCH_SIZE, 
100, brokerTopicStats);
+            orderedMap(10, topicIdPartition0, topicIdPartition1), BATCH_SIZE, 
100, brokerTopicStats);
 
         shareFetch.maybeCompleteWithException(List.of(topicIdPartition0, 
topicIdPartition1), new RuntimeException());
         assertEquals(2, future.join().size());
@@ -174,7 +173,7 @@ public class ShareFetchTest {
 
         CompletableFuture<Map<TopicIdPartition, PartitionData>> future = new 
CompletableFuture<>();
         ShareFetch shareFetch = new ShareFetch(mock(FetchParams.class), 
GROUP_ID, MEMBER_ID, future,
-            Map.of(topicIdPartition0, 10, topicIdPartition1, 10, 
topicIdPartition2, 10), BATCH_SIZE, 100, brokerTopicStats);
+            orderedMap(10, topicIdPartition0, topicIdPartition1, 
topicIdPartition2), BATCH_SIZE, 100, brokerTopicStats);
 
         shareFetch.maybeCompleteWithException(List.of(topicIdPartition0, 
topicIdPartition2), new RuntimeException());
         assertEquals(2, future.join().size());
@@ -192,7 +191,7 @@ public class ShareFetchTest {
 
         CompletableFuture<Map<TopicIdPartition, PartitionData>> future = new 
CompletableFuture<>();
         ShareFetch shareFetch = new ShareFetch(mock(FetchParams.class), 
GROUP_ID, MEMBER_ID, future,
-            Map.of(topicIdPartition0, 10, topicIdPartition1, 10), BATCH_SIZE, 
100, brokerTopicStats);
+            orderedMap(10, topicIdPartition0, topicIdPartition1), BATCH_SIZE, 
100, brokerTopicStats);
 
         shareFetch.addErroneous(topicIdPartition0, new RuntimeException());
         shareFetch.maybeCompleteWithException(List.of(topicIdPartition1), new 
RuntimeException());
diff --git 
a/server/src/test/java/org/apache/kafka/server/share/fetch/ShareFetchTestUtils.java
 
b/server/src/test/java/org/apache/kafka/server/share/fetch/ShareFetchTestUtils.java
new file mode 100644
index 00000000000..9a83bebf88b
--- /dev/null
+++ 
b/server/src/test/java/org/apache/kafka/server/share/fetch/ShareFetchTestUtils.java
@@ -0,0 +1,77 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.kafka.server.share.fetch;

import org.apache.kafka.common.TopicIdPartition;

import java.util.LinkedHashMap;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;

/**
 * Helper functions for writing share fetch unit tests.
 */
public class ShareFetchTestUtils {

    // Static-only utility class; prevent instantiation.
    private ShareFetchTestUtils() {
    }

    /**
     * Create an ordered map of TopicIdPartition to partition max bytes.
     *
     * @param partitionMaxBytes The maximum number of bytes that can be fetched for each partition.
     * @param topicIdPartitions The topic partitions to create the map for, in iteration order.
     * @return The ordered map of TopicIdPartition to partition max bytes.
     */
    public static LinkedHashMap<TopicIdPartition, Integer> orderedMap(int partitionMaxBytes, TopicIdPartition... topicIdPartitions) {
        LinkedHashMap<TopicIdPartition, Integer> map = new LinkedHashMap<>();
        for (TopicIdPartition tp : topicIdPartitions) {
            map.put(tp, partitionMaxBytes);
        }
        return map;
    }

    /**
     * Validate that the rotated map is equal to the original map with the keys rotated left by the
     * given position, i.e. the entry at index {@code rotationAt} of the original map becomes the
     * first entry of the result, and the values remain associated with the same keys.
     *
     * @param original The original map.
     * @param result The rotated map.
     * @param rotationAt The position to rotate the keys at.
     */
    public static void validateRotatedMapEquals(
        LinkedHashMap<TopicIdPartition, Integer> original,
        LinkedHashMap<TopicIdPartition, Integer> result,
        int rotationAt
    ) {
        TopicIdPartition[] originalKeys = original.keySet().toArray(new TopicIdPartition[0]);
        int size = originalKeys.length;

        // Build the expected key order: a left rotation by rotationAt positions.
        TopicIdPartition[] expectedKeys = new TopicIdPartition[size];
        for (int i = 0; i < size; i++) {
            expectedKeys[i] = originalKeys[(i + rotationAt) % size];
        }
        assertArrayEquals(expectedKeys, result.keySet().toArray(new TopicIdPartition[0]));

        // Rotation must not change the value associated with any key.
        for (TopicIdPartition key : originalKeys) {
            assertEquals(original.get(key), result.get(key));
        }
    }
}

Reply via email to