wernerdv commented on code in PR #21383:
URL: https://github.com/apache/kafka/pull/21383#discussion_r2762455791


##########
server/src/test/java/org/apache/kafka/server/FetchSessionTest.java:
##########
@@ -0,0 +1,1752 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.server;
+
+import org.apache.kafka.common.TopicIdPartition;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.Uuid;
+import org.apache.kafka.common.compress.Compression;
+import org.apache.kafka.common.message.FetchResponseData;
+import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.record.MemoryRecords;
+import org.apache.kafka.common.record.SimpleRecord;
+import org.apache.kafka.common.requests.FetchMetadata;
+import org.apache.kafka.common.requests.FetchRequest;
+import org.apache.kafka.common.requests.FetchRequest.PartitionData;
+import org.apache.kafka.common.requests.FetchResponse;
+import org.apache.kafka.common.utils.ImplicitLinkedHashCollection;
+import org.apache.kafka.server.FetchContext.FullFetchContext;
+import org.apache.kafka.server.FetchContext.IncrementalFetchContext;
+import org.apache.kafka.server.FetchContext.SessionErrorContext;
+import org.apache.kafka.server.FetchContext.SessionlessFetchContext;
+import org.apache.kafka.server.FetchSession.CachedPartition;
+import org.apache.kafka.server.FetchSession.FetchSessionCache;
+import org.apache.kafka.server.util.MockTime;
+
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+import org.junit.jupiter.params.provider.ValueSource;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import java.util.stream.Stream;
+
+import static org.apache.kafka.common.protocol.ApiKeys.FETCH;
+import static org.apache.kafka.common.requests.FetchMetadata.FINAL_EPOCH;
+import static org.apache.kafka.common.requests.FetchMetadata.INITIAL;
+import static org.apache.kafka.common.requests.FetchMetadata.INVALID_SESSION_ID;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertInstanceOf;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+@Timeout(120)
+public class FetchSessionTest {
+    private static final List<TopicIdPartition> EMPTY_PART_LIST = List.of();
+
+    @AfterEach
+    public void afterEach() {
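+        // Drop the metrics registered by the cache and reset the shared counter so state does not leak between tests.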
+        FetchSessionCache.METRICS_GROUP.removeMetric(FetchSession.NUM_INCREMENTAL_FETCH_SESSIONS);
+        FetchSessionCache.METRICS_GROUP.removeMetric(FetchSession.NUM_INCREMENTAL_FETCH_PARTITIONS_CACHED);
+        FetchSessionCache.METRICS_GROUP.removeMetric(FetchSession.INCREMENTAL_FETCH_SESSIONS_EVICTIONS_PER_SEC);
+        FetchSessionCache.COUNTER.set(0);
+    }
+
+    @Test
+    public void testNewSessionId() {
+        FetchSessionCacheShard cacheShard = new FetchSessionCacheShard(3, 100, Integer.MAX_VALUE, 0);
+        for (int i = 0; i < 10_000; i++) {
+            int id = cacheShard.newSessionId();
+            assertTrue(id > 0);
+        }
+    }
+
+    @Test
+    public void testSessionCache() {
+        FetchSessionCacheShard cacheShard = new FetchSessionCacheShard(3, 100, Integer.MAX_VALUE, 0);
+        assertEquals(0, cacheShard.size());
+
+        int id1 = cacheShard.maybeCreateSession(0, false, 10, true, () -> dummyCreate(10));
+        int id2 = cacheShard.maybeCreateSession(10, false, 20, true, () -> dummyCreate(20));
+        int id3 = cacheShard.maybeCreateSession(20, false, 30, true, () -> dummyCreate(30));
+        assertEquals(INVALID_SESSION_ID, cacheShard.maybeCreateSession(30, false, 40, true, () -> dummyCreate(40)));
+        assertEquals(INVALID_SESSION_ID, cacheShard.maybeCreateSession(40, false, 5, true, () -> dummyCreate(5)));
+        assertCacheContains(cacheShard, id1, id2, id3);
+
+        cacheShard.touch(cacheShard.get(id1).orElseThrow(), 200);
+        int id4 = cacheShard.maybeCreateSession(210, false, 11, true, () -> dummyCreate(11));
+        assertCacheContains(cacheShard, id1, id3, id4);
+
+        cacheShard.touch(cacheShard.get(id1).orElseThrow(), 400);
+        cacheShard.touch(cacheShard.get(id3).orElseThrow(), 390);
+        cacheShard.touch(cacheShard.get(id4).orElseThrow(), 400);
+        int id5 = cacheShard.maybeCreateSession(410, false, 50, true, () -> dummyCreate(50));
+        assertCacheContains(cacheShard, id3, id4, id5);
+        assertEquals(INVALID_SESSION_ID, cacheShard.maybeCreateSession(410, false, 5, true, () -> dummyCreate(5)));
+
+        int id6 = cacheShard.maybeCreateSession(410, true, 5, true, () -> dummyCreate(5));
+        assertCacheContains(cacheShard, id3, id5, id6);
+    }
+
+    @Test
+    public void testResizeCachedSessions() {
+        FetchSessionCacheShard cacheShard = new FetchSessionCacheShard(2, 100, Integer.MAX_VALUE, 0);
+        assertEquals(0, cacheShard.totalPartitions());
+        assertEquals(0, cacheShard.size());
+        assertEquals(0, cacheShard.evictionsMeter().count());
+
+        int id1 = cacheShard.maybeCreateSession(0, false, 2, true, () -> dummyCreate(2));
+        assertTrue(id1 > 0);
+        assertCacheContains(cacheShard, id1);
+
+        FetchSession session1 = cacheShard.get(id1).orElseThrow();
+        assertEquals(2, session1.size());
+        assertEquals(2, cacheShard.totalPartitions());
+        assertEquals(1, cacheShard.size());
+        assertEquals(0, cacheShard.evictionsMeter().count());
+
+        int id2 = cacheShard.maybeCreateSession(0, false, 4, true, () -> dummyCreate(4));
+        FetchSession session2 = cacheShard.get(id2).orElseThrow();
+        assertTrue(id2 > 0);
+        assertCacheContains(cacheShard, id1, id2);
+        assertEquals(6, cacheShard.totalPartitions());
+        assertEquals(2, cacheShard.size());
+        assertEquals(0, cacheShard.evictionsMeter().count());
+
+        cacheShard.touch(session1, 200);
+        cacheShard.touch(session2, 200);
+        int id3 = cacheShard.maybeCreateSession(200, false, 5, true, () -> dummyCreate(5));
+        assertTrue(id3 > 0);
+        assertCacheContains(cacheShard, id2, id3);
+        assertEquals(9, cacheShard.totalPartitions());
+        assertEquals(2, cacheShard.size());
+        assertEquals(1, cacheShard.evictionsMeter().count());
+
+        cacheShard.remove(id3);
+        assertCacheContains(cacheShard, id2);
+        assertEquals(1, cacheShard.size());
+        assertEquals(1, cacheShard.evictionsMeter().count());
+        assertEquals(4, cacheShard.totalPartitions());
+
+        Iterator<CachedPartition> iter = session2.partitionMap().iterator();
+        iter.next();
+        iter.remove();
+        assertEquals(3, session2.size());
+        assertEquals(4, session2.cachedSize());
+
+        cacheShard.touch(session2, session2.lastUsedMs());
+        assertEquals(3, cacheShard.totalPartitions());
+    }
+
+    @Test
+    public void testCachedLeaderEpoch() {
+        FetchSessionCacheShard cacheShard = new FetchSessionCacheShard(10, 1000, Integer.MAX_VALUE, 0);
+        FetchManager fetchManager = new FetchManager(new MockTime(), cacheShard);
+
+        Map<String, Uuid> topicIds = Map.of("foo", Uuid.randomUuid(), "bar", Uuid.randomUuid());
+        TopicIdPartition tp0 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 0));
+        TopicIdPartition tp1 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 1));
+        TopicIdPartition tp2 = new TopicIdPartition(topicIds.get("bar"), new TopicPartition("bar", 1));
+        Map<Uuid, String> topicNames = topicIds.entrySet().stream().collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));
+
+        LinkedHashMap<TopicPartition, FetchRequest.PartitionData> requestData1 = new LinkedHashMap<>();
+        requestData1.put(tp0.topicPartition(), new PartitionData(tp0.topicId(), 0, 0, 100, Optional.empty()));
+        requestData1.put(tp1.topicPartition(), new PartitionData(tp1.topicId(), 10, 0, 100, Optional.of(1)));
+        requestData1.put(tp2.topicPartition(), new PartitionData(tp2.topicId(), 10, 0, 100, Optional.of(2)));
+
+        FetchRequest request1 = createRequest(INITIAL, requestData1, EMPTY_PART_LIST, false, FETCH.latestVersion());
+        FetchContext context1 = newContext(fetchManager, request1, topicNames);
+        Map<TopicIdPartition, Optional<Integer>> epochs1 = cachedLeaderEpochs(context1);
+        assertEquals(Optional.empty(), epochs1.get(tp0));
+        assertEquals(Optional.of(1), epochs1.get(tp1));
+        assertEquals(Optional.of(2), epochs1.get(tp2));
+
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> response = new LinkedHashMap<>();
+        response.put(tp0, new FetchResponseData.PartitionData()
+            .setPartitionIndex(tp0.partition())
+            .setHighWatermark(100)
+            .setLastStableOffset(100)
+            .setLogStartOffset(100));
+        response.put(tp1, new FetchResponseData.PartitionData()
+            .setPartitionIndex(tp1.partition())
+            .setHighWatermark(10)
+            .setLastStableOffset(10)
+            .setLogStartOffset(10));
+        response.put(tp2, new FetchResponseData.PartitionData()
+            .setPartitionIndex(tp2.partition())
+            .setHighWatermark(5)
+            .setLastStableOffset(5)
+            .setLogStartOffset(5));
+
+        int sessionId = context1.updateAndGenerateResponseData(response, List.of()).sessionId();
+
+        // With no changes, the cached epochs should remain the same
+        FetchRequest request2 = createRequest(new FetchMetadata(sessionId, 1), new LinkedHashMap<>(),
+            EMPTY_PART_LIST, false, FETCH.latestVersion());
+        FetchContext context2 = newContext(fetchManager, request2, topicNames);
+        Map<TopicIdPartition, Optional<Integer>> epochs2 = cachedLeaderEpochs(context2);
+        assertEquals(Optional.empty(), epochs2.get(tp0));
+        assertEquals(Optional.of(1), epochs2.get(tp1));
+        assertEquals(Optional.of(2), epochs2.get(tp2));
+        context2.updateAndGenerateResponseData(response, List.of()).sessionId();
+
+        // Now verify we can change the leader epoch and the context is updated
+        LinkedHashMap<TopicPartition, FetchRequest.PartitionData> requestData3 = new LinkedHashMap<>();
+        requestData3.put(tp0.topicPartition(), new PartitionData(tp0.topicId(), 0, 0, 100, Optional.of(6)));
+        requestData3.put(tp1.topicPartition(), new PartitionData(tp1.topicId(), 10, 0, 100, Optional.empty()));
+        requestData3.put(tp2.topicPartition(), new PartitionData(tp2.topicId(), 10, 0, 100, Optional.of(3)));
+
+        FetchRequest request3 = createRequest(new FetchMetadata(sessionId, 2), requestData3,
+            EMPTY_PART_LIST, false, FETCH.latestVersion());
+        FetchContext context3 = newContext(fetchManager, request3, topicNames);
+        Map<TopicIdPartition, Optional<Integer>> epochs3 = cachedLeaderEpochs(context3);
+        assertEquals(Optional.of(6), epochs3.get(tp0));
+        assertEquals(Optional.empty(), epochs3.get(tp1));
+        assertEquals(Optional.of(3), epochs3.get(tp2));
+    }
+
+    @Test
+    public void testLastFetchedEpoch() {
+        FetchSessionCacheShard cacheShard = new FetchSessionCacheShard(10, 1000, Integer.MAX_VALUE, 0);
+        FetchManager fetchManager = new FetchManager(new MockTime(), cacheShard);
+
+        Map<String, Uuid> topicIds = Map.of("foo", Uuid.randomUuid(), "bar", Uuid.randomUuid());
+        Map<Uuid, String> topicNames = topicIds.entrySet().stream().collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));
+        TopicIdPartition tp0 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 0));
+        TopicIdPartition tp1 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 1));
+        TopicIdPartition tp2 = new TopicIdPartition(topicIds.get("bar"), new TopicPartition("bar", 1));
+
+        LinkedHashMap<TopicPartition, FetchRequest.PartitionData> requestData1 = new LinkedHashMap<>();
+        requestData1.put(tp0.topicPartition(), new PartitionData(tp0.topicId(), 0, 0, 100, Optional.empty(), Optional.empty()));
+        requestData1.put(tp1.topicPartition(), new PartitionData(tp1.topicId(), 10, 0, 100, Optional.of(1), Optional.empty()));
+        requestData1.put(tp2.topicPartition(), new PartitionData(tp2.topicId(), 10, 0, 100, Optional.of(2), Optional.of(1)));
+
+        FetchRequest request1 = createRequest(INITIAL, requestData1, EMPTY_PART_LIST, false, FETCH.latestVersion());
+        FetchContext context1 = newContext(fetchManager, request1, topicNames);
+        assertEquals(Map.of(tp0, Optional.empty(), tp1, Optional.of(1), tp2, Optional.of(2)), cachedLeaderEpochs(context1));
+        assertEquals(Map.of(tp0, Optional.empty(), tp1, Optional.empty(), tp2, Optional.of(1)), cachedLastFetchedEpochs(context1));
+
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> response = new LinkedHashMap<>();
+        response.put(tp0, new FetchResponseData.PartitionData()
+            .setPartitionIndex(tp0.partition())
+            .setHighWatermark(100)
+            .setLastStableOffset(100)
+            .setLogStartOffset(100));
+        response.put(tp1, new FetchResponseData.PartitionData()
+            .setPartitionIndex(tp1.partition())
+            .setHighWatermark(10)
+            .setLastStableOffset(10)
+            .setLogStartOffset(10));
+        response.put(tp2, new FetchResponseData.PartitionData()
+            .setPartitionIndex(tp2.partition())
+            .setHighWatermark(5)
+            .setLastStableOffset(5)
+            .setLogStartOffset(5));
+
+        int sessionId = context1.updateAndGenerateResponseData(response, List.of()).sessionId();
+
+        // With no changes, the cached epochs should remain the same
+        FetchRequest request2 = createRequest(new FetchMetadata(sessionId, 1), new LinkedHashMap<>(),
+            EMPTY_PART_LIST, false, FETCH.latestVersion());
+        FetchContext context2 = newContext(fetchManager, request2, topicNames);
+        assertEquals(Map.of(tp0, Optional.empty(), tp1, Optional.of(1), tp2, Optional.of(2)), cachedLeaderEpochs(context2));
+        assertEquals(Map.of(tp0, Optional.empty(), tp1, Optional.empty(), tp2, Optional.of(1)), cachedLastFetchedEpochs(context2));
+        context2.updateAndGenerateResponseData(response, List.of()).sessionId();
+
+        // Now verify we can change the leader epoch and the context is updated
+        LinkedHashMap<TopicPartition, FetchRequest.PartitionData> requestData3 = new LinkedHashMap<>();
+        requestData3.put(tp0.topicPartition(), new PartitionData(tp0.topicId(), 0, 0, 100, Optional.of(6), Optional.of(5)));
+        requestData3.put(tp1.topicPartition(), new PartitionData(tp1.topicId(), 10, 0, 100, Optional.empty(), Optional.empty()));
+        requestData3.put(tp2.topicPartition(), new PartitionData(tp2.topicId(), 10, 0, 100, Optional.of(3), Optional.of(3)));
+
+        FetchRequest request3 = createRequest(new FetchMetadata(sessionId, 2), requestData3, EMPTY_PART_LIST, false, FETCH.latestVersion());
+        FetchContext context3 = newContext(fetchManager, request3, topicNames);
+        assertEquals(Map.of(tp0, Optional.of(6), tp1, Optional.empty(), tp2, Optional.of(3)), cachedLeaderEpochs(context3));
+        assertEquals(Map.of(tp0, Optional.of(5), tp1, Optional.empty(), tp2, Optional.of(3)), cachedLastFetchedEpochs(context3));
+    }
+
+    @Test
+    public void testFetchRequests() {
+        FetchSessionCacheShard cacheShard = new FetchSessionCacheShard(10, 1000, Integer.MAX_VALUE, 0);
+        FetchManager fetchManager = new FetchManager(new MockTime(), cacheShard);
+        Map<Uuid, String> topicNames = Map.of(Uuid.randomUuid(), "foo", Uuid.randomUuid(), "bar");
+        Map<String, Uuid> topicIds = topicNames.entrySet().stream().collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));
+        TopicIdPartition tp0 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 0));
+        TopicIdPartition tp1 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 1));
+        TopicIdPartition tp2 = new TopicIdPartition(topicIds.get("bar"), new TopicPartition("bar", 0));
+        TopicIdPartition tp3 = new TopicIdPartition(topicIds.get("bar"), new TopicPartition("bar", 1));
+
+        // Verify that SESSIONLESS requests get a SessionlessFetchContext
+        FetchRequest request = createRequest(FetchMetadata.LEGACY, new HashMap<>(), EMPTY_PART_LIST, true, FETCH.latestVersion());
+        FetchContext context = newContext(fetchManager, request, topicNames);
+        assertInstanceOf(SessionlessFetchContext.class, context);
+
+        // Create a new fetch session with a FULL fetch request
+        LinkedHashMap<TopicPartition, FetchRequest.PartitionData> reqData2 = new LinkedHashMap<>();
+        reqData2.put(tp0.topicPartition(), new PartitionData(tp0.topicId(), 0, 0, 100, Optional.empty()));
+        reqData2.put(tp1.topicPartition(), new PartitionData(tp1.topicId(), 10, 0, 100, Optional.empty()));
+        FetchRequest request2 = createRequest(INITIAL, reqData2, EMPTY_PART_LIST, false, FETCH.latestVersion());
+        FetchContext context2 = newContext(fetchManager, request2, topicNames);
+        assertInstanceOf(FullFetchContext.class, context2);
+
+        Iterator<Map.Entry<TopicPartition, FetchRequest.PartitionData>> reqData2Iter = reqData2.entrySet().iterator();
+        context2.foreachPartition((topicIdPart, data) -> {
+            Map.Entry<TopicPartition, FetchRequest.PartitionData> entry = reqData2Iter.next();
+            assertEquals(entry.getKey(), topicIdPart.topicPartition());
+            assertEquals(topicIds.get(entry.getKey().topic()), topicIdPart.topicId());
+            assertEquals(entry.getValue(), data);
+        });
+        assertEquals(0, context2.getFetchOffset(tp0).orElseThrow());
+        assertEquals(10, context2.getFetchOffset(tp1).orElseThrow());
+
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> respData2 = new LinkedHashMap<>();
+        respData2.put(tp0, new FetchResponseData.PartitionData()
+            .setPartitionIndex(0)
+            .setHighWatermark(100)
+            .setLastStableOffset(100)
+            .setLogStartOffset(100));
+        respData2.put(tp1, new FetchResponseData.PartitionData()
+            .setPartitionIndex(1)
+            .setHighWatermark(10)
+            .setLastStableOffset(10)
+            .setLogStartOffset(10));
+        FetchResponse resp2 = context2.updateAndGenerateResponseData(respData2, List.of());
+        assertEquals(Errors.NONE, resp2.error());
+        assertTrue(resp2.sessionId() != INVALID_SESSION_ID);
+        assertEquals(
+            respData2.entrySet()
+                .stream()
+                .collect(Collectors.toMap(entry -> entry.getKey().topicPartition(), Map.Entry::getValue)),
+            resp2.responseData(topicNames, request2.version())
+        );
+
+        // Test trying to create a new session with an invalid epoch
+        FetchRequest request3 = createRequest(new FetchMetadata(resp2.sessionId(), 5), reqData2,
+            EMPTY_PART_LIST, false, FETCH.latestVersion());
+        FetchContext context3 = newContext(fetchManager, request3, topicNames);
+        assertInstanceOf(SessionErrorContext.class, context3);
+        assertEquals(Errors.INVALID_FETCH_SESSION_EPOCH, context3.updateAndGenerateResponseData(respData2, List.of()).error());
+
+        // Test trying to create a new session with a non-existent session id
+        FetchRequest request4 = createRequest(new FetchMetadata(resp2.sessionId() + 1, 1), reqData2,
+            EMPTY_PART_LIST, false, FETCH.latestVersion());
+        FetchContext context4 = newContext(fetchManager, request4, topicNames);
+        assertEquals(Errors.FETCH_SESSION_ID_NOT_FOUND, context4.updateAndGenerateResponseData(respData2, List.of()).error());
+
+        // Continue the first fetch session we created.
+        FetchRequest request5 = createRequest(new FetchMetadata(resp2.sessionId(), 1), new LinkedHashMap<>(),
+            EMPTY_PART_LIST, false, FETCH.latestVersion());
+        FetchContext context5 = newContext(fetchManager, request5, topicNames);
+        assertInstanceOf(IncrementalFetchContext.class, context5);
+
+        Iterator<Map.Entry<TopicPartition, FetchRequest.PartitionData>> reqData5Iter = reqData2.entrySet().iterator();
+        context5.foreachPartition((topicIdPart, data) -> {
+            Map.Entry<TopicPartition, FetchRequest.PartitionData> entry = reqData5Iter.next();
+            assertEquals(entry.getKey(), topicIdPart.topicPartition());
+            assertEquals(topicIds.get(entry.getKey().topic()), topicIdPart.topicId());
+            assertEquals(entry.getValue(), data);
+        });
+        assertEquals(10, context5.getFetchOffset(tp1).orElseThrow());
+
+        FetchResponse resp5 = context5.updateAndGenerateResponseData(respData2, List.of());
+        assertEquals(Errors.NONE, resp5.error());
+        assertEquals(resp2.sessionId(), resp5.sessionId());
+        assertEquals(0, resp5.responseData(topicNames, request5.version()).size());
+
+        // Test setting an invalid fetch session epoch.
+        FetchRequest request6 = createRequest(new FetchMetadata(resp2.sessionId(), 5), reqData2,
+            EMPTY_PART_LIST, false, FETCH.latestVersion());
+        FetchContext context6 = newContext(fetchManager, request6, topicNames);
+        assertInstanceOf(SessionErrorContext.class, context6);
+        assertEquals(Errors.INVALID_FETCH_SESSION_EPOCH, context6.updateAndGenerateResponseData(respData2, List.of()).error());
+
+        // Test generating a throttled response for the incremental fetch session
+        FetchRequest request7 = createRequest(new FetchMetadata(resp2.sessionId(), 2), new LinkedHashMap<>(),
+            EMPTY_PART_LIST, false, FETCH.latestVersion());
+        FetchContext context7 = newContext(fetchManager, request7, topicNames);
+        FetchResponse resp7 = context7.getThrottledResponse(100, List.of());
+        assertEquals(Errors.NONE, resp7.error());
+        assertEquals(resp2.sessionId(), resp7.sessionId());
+        assertEquals(100, resp7.throttleTimeMs());
+
+        // Close the incremental fetch session.
+        int prevSessionId = resp5.sessionId();
+        int nextSessionId;
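+        // Repeat until the response reports a session id different from the one being closed.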
+        do {
+            LinkedHashMap<TopicPartition, FetchRequest.PartitionData> reqData8 = new LinkedHashMap<>();
+            reqData8.put(tp2.topicPartition(), new PartitionData(tp2.topicId(), 0, 0, 100, Optional.empty()));
+            reqData8.put(tp3.topicPartition(), new PartitionData(tp3.topicId(), 10, 0, 100, Optional.empty()));
+            FetchRequest request8 = createRequest(new FetchMetadata(prevSessionId, FINAL_EPOCH), reqData8,
+                EMPTY_PART_LIST, false, FETCH.latestVersion());
+            FetchContext context8 = newContext(fetchManager, request8, topicNames);
+            assertInstanceOf(SessionlessFetchContext.class, context8);
+            assertEquals(0, cacheShard.size());
+
+            LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> respData8 = new LinkedHashMap<>();
+            respData8.put(tp2, new FetchResponseData.PartitionData()
+                .setPartitionIndex(0)
+                .setHighWatermark(100)
+                .setLastStableOffset(100)
+                .setLogStartOffset(100));
+            respData8.put(tp3, new FetchResponseData.PartitionData()
+                .setPartitionIndex(1)
+                .setHighWatermark(100)
+                .setLastStableOffset(100)
+                .setLogStartOffset(100));
+            FetchResponse resp8 = context8.updateAndGenerateResponseData(respData8, List.of());
+            assertEquals(Errors.NONE, resp8.error());
+
+            nextSessionId = resp8.sessionId();
+        } while (nextSessionId == prevSessionId);
+    }
+
+    @ParameterizedTest
+    @ValueSource(booleans = {true, false})
+    public void testIncrementalFetchSession(boolean usesTopicIds) {
+        FetchSessionCacheShard cacheShard = new FetchSessionCacheShard(10, 1000, Integer.MAX_VALUE, 0);
+        FetchManager fetchManager = new FetchManager(new MockTime(), cacheShard);
+        Map<Uuid, String> topicNames = usesTopicIds
+            ? Map.of(Uuid.randomUuid(), "foo", Uuid.randomUuid(), "bar")
+            : Map.of();
+        Map<String, Uuid> topicIds = topicNames.entrySet().stream().collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));
+        short version = usesTopicIds ? FETCH.latestVersion() : (short) 12;
+        Uuid fooId = topicIds.getOrDefault("foo", Uuid.ZERO_UUID);
+        Uuid barId = topicIds.getOrDefault("bar", Uuid.ZERO_UUID);
+        TopicIdPartition tp0 = new TopicIdPartition(fooId, new TopicPartition("foo", 0));
+        TopicIdPartition tp1 = new TopicIdPartition(fooId, new TopicPartition("foo", 1));
+        TopicIdPartition tp2 = new TopicIdPartition(barId, new TopicPartition("bar", 0));
+
+        // Create a new fetch session with foo-0 and foo-1
+        LinkedHashMap<TopicPartition, FetchRequest.PartitionData> reqData1 = new LinkedHashMap<>();
+        reqData1.put(tp0.topicPartition(), new PartitionData(fooId, 0, 0, 100, Optional.empty()));
+        reqData1.put(tp1.topicPartition(), new PartitionData(fooId, 10, 0, 100, Optional.empty()));
+        FetchRequest request1 = createRequest(INITIAL, reqData1, EMPTY_PART_LIST, false, version);
+        FetchContext context1 = newContext(fetchManager, request1, topicNames);
+        assertInstanceOf(FullFetchContext.class, context1);
+
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> respData1 = new LinkedHashMap<>();
+        respData1.put(tp0, new FetchResponseData.PartitionData()
+            .setPartitionIndex(0)
+            .setHighWatermark(100)
+            .setLastStableOffset(100)
+            .setLogStartOffset(100));
+        respData1.put(tp1, new FetchResponseData.PartitionData()
+            .setPartitionIndex(1)
+            .setHighWatermark(10)
+            .setLastStableOffset(10)
+            .setLogStartOffset(10));
+        FetchResponse resp1 = context1.updateAndGenerateResponseData(respData1, List.of());
+        assertEquals(Errors.NONE, resp1.error());
+        assertTrue(resp1.sessionId() != INVALID_SESSION_ID);
+        assertEquals(2, resp1.responseData(topicNames, request1.version()).size());
+
+        // Create an incremental fetch request that removes foo-0 and adds bar-0
+        LinkedHashMap<TopicPartition, FetchRequest.PartitionData> reqData2 = new LinkedHashMap<>();
+        reqData2.put(tp2.topicPartition(), new PartitionData(barId, 15, 0, 0, Optional.empty()));
+        FetchRequest request2 = createRequest(new FetchMetadata(resp1.sessionId(), 1), reqData2, List.of(tp0), false, version);
+        FetchContext context2 = newContext(fetchManager, request2, topicNames);
+        assertInstanceOf(IncrementalFetchContext.class, context2);
+
+        Set<TopicIdPartition> parts = new LinkedHashSet<>();
+        parts.add(tp1);
+        parts.add(tp2);

Review Comment:
   I tried it, and the `testIncrementalFetchSession()` test sometimes fails with an error:
   ```
   org.opentest4j.AssertionFailedError: expected: <cJYw5cd4SeidggxC6nx8-Q:bar-0> but was: <cWswQ2D7Sk6_xKCObHD00w:foo-1>
   ```
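
   If the root cause is ordering, note that the topic ids here come from `Uuid.randomUuid()`, so the order in which the session yields `foo-1` and `bar-0` can change between runs. One possible way to make the check deterministic is to compare the context's partitions as an unordered set rather than relying on iteration order. A minimal sketch against this test's local variables (only `foreachPartition` is taken from the diff above; the rest is a suggestion, not the PR's code, and needs a `java.util.HashSet` import):
   ```java
   // Collect whatever partitions the incremental context reports, then compare
   // as sets so the random topic ids cannot affect the expected order.
   Set<TopicIdPartition> actual = new HashSet<>();
   context2.foreachPartition((topicIdPart, data) -> actual.add(topicIdPart));
   assertEquals(Set.of(tp1, tp2), actual);
   ```
   Alternatively, pinning the topic ids to fixed `Uuid` values in the setup would keep the existing order-sensitive assertions stable.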



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
