This is an automated email from the ASF dual-hosted git repository.

lianetm pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 057460e807a KAFKA-17182: Consumer fetch sessions are evicted too quickly with AsyncKafkaConsumer (#18795)
057460e807a is described below

commit 057460e807a218b169c69ae0b849e0b3b4d27f42
Author: Kirk True <[email protected]>
AuthorDate: Thu Feb 13 10:53:56 2025 -0800

    KAFKA-17182: Consumer fetch sessions are evicted too quickly with AsyncKafkaConsumer (#18795)
    
    Reviewers: Jun Rao <[email protected]>, Lianet Magrans <[email protected]>, Jeff Kim <[email protected]>
---
 .../clients/consumer/internals/AbstractFetch.java  | 153 ++++++-
 .../internals/FetchRequestManagerTest.java         | 504 ++++++++++++++++++++-
 .../clients/consumer/internals/FetcherTest.java    |  49 +-
 3 files changed, 633 insertions(+), 73 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractFetch.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractFetch.java
index e3d4eb58af4..8451ded8d85 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractFetch.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractFetch.java
@@ -22,6 +22,7 @@ import org.apache.kafka.clients.FetchSessionHandler;
 import org.apache.kafka.clients.KafkaClient;
 import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.NetworkClientUtils;
+import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.common.Cluster;
 import org.apache.kafka.common.Node;
 import org.apache.kafka.common.TopicPartition;
@@ -44,6 +45,8 @@ import org.slf4j.helpers.MessageFormatter;
 
 import java.io.Closeable;
 import java.time.Duration;
+import java.util.Collection;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
@@ -315,17 +318,15 @@ public abstract class AbstractFetch implements Closeable {
     }
 
     /**
-     * Return the list of <em>fetchable</em> partitions, which are the set of partitions to which we are subscribed,
+     * Return the set of <em>fetchable</em> partitions, which are the set of partitions to which we are subscribed,
      * but <em>excluding</em> any partitions for which we still have buffered data. The idea is that since the user
      * has yet to process the data for the partition that has already been fetched, we should not go send for more data
      * until the previously-fetched data has been processed.
      *
+     * @param buffered The set of partitions we have in our buffer
      * @return {@link Set} of {@link TopicPartition topic partitions} for which we should fetch data
      */
-    private Set<TopicPartition> fetchablePartitions() {
-        // This is the set of partitions we have in our buffer
-        Set<TopicPartition> buffered = fetchBuffer.bufferedPartitions();
-
+    private Set<TopicPartition> fetchablePartitions(Set<TopicPartition> buffered) {
         // This is the test that returns true if the partition is *not* buffered
         Predicate<TopicPartition> isNotBuffered = tp -> !buffered.contains(tp);
 
@@ -408,22 +409,28 @@ public abstract class AbstractFetch implements Closeable {
         long currentTimeMs = time.milliseconds();
         Map<String, Uuid> topicIds = metadata.topicIds();
 
-        for (TopicPartition partition : fetchablePartitions()) {
-            SubscriptionState.FetchPosition position = subscriptions.position(partition);
+        // This is the set of partitions that have buffered data
+        Set<TopicPartition> buffered = Collections.unmodifiableSet(fetchBuffer.bufferedPartitions());
+
+        // This is the set of partitions that do not have buffered data
+        Set<TopicPartition> unbuffered = fetchablePartitions(buffered);
 
-            if (position == null)
-                throw new IllegalStateException("Missing position for fetchable partition " + partition);
+        if (unbuffered.isEmpty()) {
+            // If there are no partitions that don't already have data locally buffered, there's no need to issue
+            // any fetch requests at the present time.
+            return Collections.emptyMap();
+        }
 
-            Optional<Node> leaderOpt = position.currentLeader.leader;
+        Set<Integer> bufferedNodes = bufferedNodes(buffered, currentTimeMs);
 
-            if (leaderOpt.isEmpty()) {
-                log.debug("Requesting metadata update for partition {} since the position {} is missing the current leader node", partition, position);
-                metadata.requestUpdate(false);
+        for (TopicPartition partition : unbuffered) {
+            SubscriptionState.FetchPosition position = positionForPartition(partition);
+            Optional<Node> nodeOpt = maybeNodeForPosition(partition, position, currentTimeMs);
+
+            if (nodeOpt.isEmpty())
                 continue;
-            }
 
-            // Use the preferred read replica if set, otherwise the partition's leader
-            Node node = selectReadReplica(partition, leaderOpt.get(), currentTimeMs);
+            Node node = nodeOpt.get();
 
             if (isUnavailable(node)) {
                 maybeThrowAuthFailure(node);
@@ -432,7 +439,14 @@ public abstract class AbstractFetch implements Closeable {
                 // going to be failed anyway before being sent, so skip sending the request for now
                 log.trace("Skipping fetch for partition {} because node {} is awaiting reconnect backoff", partition, node);
             } else if (nodesWithPendingFetchRequests.contains(node.id())) {
+                // If there's already an inflight request for this node, don't issue another request.
                 log.trace("Skipping fetch for partition {} because previous request to {} has not been processed", partition, node);
+            } else if (bufferedNodes.contains(node.id())) {
+                // While a node has buffered data, don't fetch other partition data from it. Because the buffered
+                // partitions are not included in the fetch request, those partitions will be inadvertently dropped
+                // from the broker fetch session cache. In some cases, that could lead to the entire fetch session
+                // being evicted.
+                log.trace("Skipping fetch for partition {} because its leader node {} hosts buffered partitions", partition, node);
             } else {
                 // if there is a leader and no in-flight requests, issue a new fetch
                 FetchSessionHandler.Builder builder = fetchable.computeIfAbsent(node, k -> {
@@ -456,6 +470,113 @@ public abstract class AbstractFetch implements Closeable {
         return fetchable.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().build()));
     }
 
+    /**
+     * Simple utility method that returns a {@link SubscriptionState.FetchPosition position} for the partition. If
+     * no position exists, an {@link IllegalStateException} is thrown.
+     */
+    private SubscriptionState.FetchPosition positionForPartition(TopicPartition partition) {
+        SubscriptionState.FetchPosition position = subscriptions.position(partition);
+
+        if (position == null)
+            throw new IllegalStateException("Missing position for fetchable partition " + partition);
+
+        return position;
+    }
+
+    /**
+     * Retrieves the node from which to fetch the partition data. If the given
+     * {@link SubscriptionState.FetchPosition position} does not have a current
+     * {@link Metadata.LeaderAndEpoch#leader leader} defined the method will return {@link Optional#empty()}.
+     *
+     * @return Three options: 1) {@link Optional#empty()} if the position's leader is empty, 2) the
+     * {@link #selectReadReplica(TopicPartition, Node, long) read replica, if defined}, or 3) the position's
+     * {@link Metadata.LeaderAndEpoch#leader leader}
+     */
+    private Optional<Node> maybeNodeForPosition(TopicPartition partition,
+                                                SubscriptionState.FetchPosition position,
+                                                long currentTimeMs) {
+        Optional<Node> leaderOpt = position.currentLeader.leader;
+
+        if (leaderOpt.isEmpty()) {
+            log.debug("Requesting metadata update for partition {} since the position {} is missing the current leader node", partition, position);
+            metadata.requestUpdate(false);
+            return Optional.empty();
+        }
+
+        // Use the preferred read replica if set, otherwise the partition's leader
+        Node node = selectReadReplica(partition, leaderOpt.get(), currentTimeMs);
+        return Optional.of(node);
+    }
+
+    /**
+     * Returns the set of IDs for {@link Node}s to which fetch requests should <em>not</em> be sent.
+     *
+     * <p>
+     * When a partition has buffered data in {@link FetchBuffer}, that means that at some point in the <em>past</em>,
+     * the following steps occurred:
+     *
+     * <ol>
+     *     <li>The client submitted a fetch request to the partition's leader</li>
+     *     <li>The leader responded with data</li>
+     *     <li>The client received a response from the leader and stored that data in memory</li>
+     * </ol>
+     *
+     * But it's possible that at the <em>current</em> point in time, that same partition might not be in a fetchable
+     * state. For example:
+     *
+     * <ul>
+     *     <li>
+     *         The partition is no longer assigned to the client. This also includes when the partition assignment
+     *         is either {@link SubscriptionState#markPendingRevocation(Set) pending revocation} or
+     *         {@link SubscriptionState#markPendingOnAssignedCallback(Collection, boolean) pending assignment}.
+     *     </li>
+     *     <li>
+     *         The client {@link Consumer#pause(Collection) paused} the partition. A paused partition remains in
+     *         the fetch buffer, because {@link FetchCollector#collectFetch(FetchBuffer)} explicitly skips over
+     *         paused partitions and does not return them to the user.
+     *     </li>
+     *     <li>
+     *         The partition does not have a valid position on the client. This could be due to the partition
+     *         awaiting validation or awaiting reset.
+     *     </li>
+     * </ul>
+     *
+     * For those reasons, a partition that was <em>previously</em> in a fetchable state might not <em>currently</em>
+     * be in a fetchable state.
+     * </p>
+     *
+     * <p>
+     * Here's why this is important—in a production system, a given leader node serves as a leader for many partitions.
+     * From the client's perspective, it's possible that a node has a mix of both fetchable and unfetchable partitions.
+     * When the client determines which nodes to skip and which to fetch from, it's important that unfetchable
+     * partitions don't block fetchable partitions from being fetched.
+     * </p>
+     *
+     * <p>
+     * So, when it's determined that a buffered partition is not in a fetchable state, it should be skipped.
+     * Otherwise, its node would end up in the set of nodes with buffered data and no fetch would be requested.
+     * </p>
+     *
+     * @param partitions Buffered partitions
+     * @param currentTimeMs Current timestamp
+     *
+     * @return Set of zero or more IDs for leader nodes of buffered partitions
+     */
+    private Set<Integer> bufferedNodes(Set<TopicPartition> partitions, long currentTimeMs) {
+        Set<Integer> ids = new HashSet<>();
+
+        for (TopicPartition partition : partitions) {
+            if (!subscriptions.isFetchable(partition))
+                continue;
+
+            SubscriptionState.FetchPosition position = positionForPartition(partition);
+            Optional<Node> nodeOpt = maybeNodeForPosition(partition, position, currentTimeMs);
+            nodeOpt.ifPresent(node -> ids.add(node.id()));
+        }
+
+        return ids;
+    }
+
     // Visible for testing
     protected FetchSessionHandler sessionHandler(int node) {
         return sessionHandlers.get(node);
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetchRequestManagerTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetchRequestManagerTest.java
index 3cdc0ac4845..0e282eca832 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetchRequestManagerTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetchRequestManagerTest.java
@@ -24,6 +24,7 @@ import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.MetadataRecoveryStrategy;
 import org.apache.kafka.clients.MockClient;
 import org.apache.kafka.clients.NetworkClient;
+import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.ConsumerConfig;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.clients.consumer.OffsetOutOfRangeException;
@@ -48,9 +49,6 @@ import org.apache.kafka.common.header.internals.RecordHeader;
 import org.apache.kafka.common.internals.ClusterResourceListeners;
 import org.apache.kafka.common.message.ApiMessageType;
 import org.apache.kafka.common.message.FetchResponseData;
-import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData;
-import 
org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset;
-import 
org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.OffsetForLeaderTopicResult;
 import org.apache.kafka.common.metrics.KafkaMetric;
 import org.apache.kafka.common.metrics.MetricConfig;
 import org.apache.kafka.common.metrics.Metrics;
@@ -75,7 +73,6 @@ import org.apache.kafka.common.requests.FetchRequest;
 import org.apache.kafka.common.requests.FetchRequest.PartitionData;
 import org.apache.kafka.common.requests.FetchResponse;
 import org.apache.kafka.common.requests.MetadataResponse;
-import org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse;
 import org.apache.kafka.common.requests.RequestTestUtils;
 import org.apache.kafka.common.serialization.ByteArrayDeserializer;
 import org.apache.kafka.common.serialization.BytesDeserializer;
@@ -109,6 +106,7 @@ import java.nio.charset.StandardCharsets;
 import java.time.Duration;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -116,11 +114,13 @@ import java.util.Iterator;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Optional;
 import java.util.Properties;
 import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.Future;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
@@ -211,11 +211,15 @@ public class FetchRequestManagerTest {
     }
 
     private void assignFromUser(Set<TopicPartition> partitions) {
+        assignFromUser(partitions, 1);
+    }
+
+    private void assignFromUser(Set<TopicPartition> partitions, int numNodes) {
         subscriptions.assignFromUser(partitions);
-        client.updateMetadata(initialUpdateResponse);
+        client.updateMetadata(RequestTestUtils.metadataUpdateWithIds(numNodes, 
singletonMap(topicName, 4), topicIds));
 
         // A dummy metadata update to ensure valid leader epoch.
-        
metadata.updateWithCurrentRequestVersion(RequestTestUtils.metadataUpdateWithIds("dummy",
 1,
+        
metadata.updateWithCurrentRequestVersion(RequestTestUtils.metadataUpdateWithIds("dummy",
 numNodes,
                 Collections.emptyMap(), singletonMap(topicName, 4),
                 tp -> validLeaderEpoch, topicIds), false, 0L);
     }
@@ -1429,7 +1433,7 @@ public class FetchRequestManagerTest {
 
         Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> 
fetchedRecords;
 
-        assignFromUser(Set.of(tp0, tp1));
+        assignFromUser(Set.of(tp0, tp1), 2);    // Use multiple nodes so 
partitions have different leaders
 
         // seek to tp0 and tp1 in two polls to generate 2 complete requests 
and responses
 
@@ -1461,7 +1465,7 @@ public class FetchRequestManagerTest {
     public void testFetchOnCompletedFetchesForAllPausedPartitions() {
         buildFetcher();
 
-        assignFromUser(Set.of(tp0, tp1));
+        assignFromUser(Set.of(tp0, tp1), 2);    // Use multiple nodes so 
partitions have different leaders
 
         // seek to tp0 and tp1 in two polls to generate 2 complete requests 
and responses
 
@@ -1837,7 +1841,9 @@ public class FetchRequestManagerTest {
         buildFetcher(AutoOffsetResetStrategy.NONE, new ByteArrayDeserializer(),
                 new ByteArrayDeserializer(), 2, 
IsolationLevel.READ_UNCOMMITTED);
 
-        assignFromUser(Set.of(tp0));
+        // Use multiple nodes so partitions have different leaders. tp0 is 
added here, but tp1 is also assigned
+        // about halfway down.
+        assignFromUser(Set.of(tp0), 2);
         subscriptions.seek(tp0, 1);
         assertEquals(1, sendFetches());
         Map<TopicIdPartition, FetchResponseData.PartitionData> partitions = 
new HashMap<>();
@@ -3437,21 +3443,461 @@ public class FetchRequestManagerTest {
 
     }
 
-    private OffsetsForLeaderEpochResponse prepareOffsetsForLeaderEpochResponse(
-            TopicPartition topicPartition,
-            Errors error,
-            int leaderEpoch,
-            long endOffset
-    ) {
-        OffsetForLeaderEpochResponseData data = new 
OffsetForLeaderEpochResponseData();
-        data.topics().add(new OffsetForLeaderTopicResult()
-                .setTopic(topicPartition.topic())
-                .setPartitions(Collections.singletonList(new EpochEndOffset()
-                        .setPartition(topicPartition.partition())
-                        .setErrorCode(error.code())
-                        .setLeaderEpoch(leaderEpoch)
-                        .setEndOffset(endOffset))));
-        return new OffsetsForLeaderEpochResponse(data);
+    /**
+     * This test makes several calls to {@link #sendFetches()}, and after 
each, the buffered partitions are
+     * modified to either cause (or prevent) a fetch from being requested.
+     */
+    @Test
+    public void testFetchRequestWithBufferedPartitions() {
+        buildFetcher();
+
+        // The partitions are spread across multiple nodes to ensure the 
fetcher's logic correctly handles the
+        // partition-to-node mapping.
+        int numNodes = 2;
+        Set<TopicPartition> partitions = Set.of(tp0, tp1, tp2, tp3);
+        assignFromUser(partitions, numNodes);
+
+        // Seek each partition so that it becomes eligible to fetch.
+        partitions.forEach(tp -> subscriptions.seek(tp, 0));
+
+        // Get all the nodes serving as the leader for these partitions.
+        List<Node> nodes = nodesForPartitionLeaders(partitions);
+
+        // Extract the nodes and their respective set of partitions to make 
things easier to keep track of later.
+        assertEquals(2, nodes.size());
+        Node node0 = nodes.get(0);
+        Node node1 = nodes.get(1);
+        List<TopicPartition> node0Partitions = partitionsForNode(node0, 
partitions);
+        List<TopicPartition> node1Partitions = partitionsForNode(node1, 
partitions);
+        assertEquals(2, node0Partitions.size());
+        assertEquals(2, node1Partitions.size());
+        TopicPartition node0Partition1 = node0Partitions.get(0);
+        TopicPartition node0Partition2 = node0Partitions.get(1);
+        TopicPartition node1Partition1 = node1Partitions.get(0);
+        TopicPartition node1Partition2 = node1Partitions.get(1);
+
+        // sendFetches() call #1 should issue requests to node 0 or node 1 
since neither has buffered data.
+        List<NetworkClientDelegate.UnsentRequest> call1 = 
fetcher.sendFetches();
+        assertEquals(2, call1.size());
+        assertEquals(partitions, partitionsRequested(call1));
+        assertEquals(new HashSet<>(nodes), nodesRequested(call1));
+
+        prepareFetchResponses(node0, node0Partitions, 0);
+        prepareFetchResponses(node1, node1Partitions, 0);
+        networkClientDelegate.poll(time.timer(0));
+
+        assertEquals(4, fetcher.fetchBuffer.bufferedPartitions().size());
+        collectSelectedPartition(node0Partition1, partitions);
+        node0Partitions.remove(node0Partition1);
+        assertEquals(3, fetcher.fetchBuffer.bufferedPartitions().size());
+
+        // sendFetches() call #2 shouldn't issue requests to either node 0 or 
node 1 since they both have buffered data.
+        List<NetworkClientDelegate.UnsentRequest> call2 = 
fetcher.sendFetches();
+        assertEquals(0, call2.size());
+
+        networkClientDelegate.poll(time.timer(0));
+        collectSelectedPartition(node1Partition1, partitions);
+        node1Partitions.remove(node1Partition1);
+        assertEquals(2, fetcher.fetchBuffer.bufferedPartitions().size());
+
+        // sendFetches() call #3 shouldn't issue requests to either node 0 or 
node 1 since they both have buffered data.
+        List<NetworkClientDelegate.UnsentRequest> call3 = 
fetcher.sendFetches();
+        assertEquals(0, call3.size());
+
+        networkClientDelegate.poll(time.timer(0));
+        collectSelectedPartition(node0Partition2, partitions);
+        node0Partitions.remove(node0Partition2);
+        assertEquals(1, fetcher.fetchBuffer.bufferedPartitions().size());
+
+        // Validate that all of node 0's partitions have all been collected.
+        assertTrue(node0Partitions.isEmpty());
+
+        // Reset the list of partitions for node 0 so the next fetch pass 
requests data.
+        node0Partitions = partitionsForNode(node0, partitions);
+
+        // sendFetches() call #4 should issue a request to node 0 since its 
buffered data was collected.
+        List<NetworkClientDelegate.UnsentRequest> call4 = 
fetcher.sendFetches();
+        assertEquals(1, call4.size());
+        assertEquals(Set.of(node0Partition1, node0Partition2), 
partitionsRequested(call4));
+        assertEquals(Set.of(node0), nodesRequested(call4));
+
+        prepareFetchResponses(node0, node0Partitions, 10);
+        networkClientDelegate.poll(time.timer(0));
+
+        collectSelectedPartition(node1Partition2, partitions);
+        node1Partitions.remove(node1Partition2);
+        assertEquals(2, fetcher.fetchBuffer.bufferedPartitions().size());
+
+        // Node 1's partitions have likewise all been collected, so validate 
that.
+        assertTrue(node1Partitions.isEmpty());
+
+        // Again, reset the list of partitions, this time for node 1, so the 
next fetch pass requests data.
+        node1Partitions = partitionsForNode(node1, partitions);
+
+        // sendFetches() call #5 should issue a request to node 1 since its 
buffered data was collected.
+        List<NetworkClientDelegate.UnsentRequest> call5 = 
fetcher.sendFetches();
+        assertEquals(1, call5.size());
+        assertEquals(Set.of(node1Partition1, node1Partition2), 
partitionsRequested(call5));
+        assertEquals(Set.of(node1), nodesRequested(call5));
+
+        prepareFetchResponses(node1, node1Partitions, 10);
+        networkClientDelegate.poll(time.timer(0));
+        assertEquals(4, fetcher.fetchBuffer.bufferedPartitions().size());
+
+        // Collect all the records and make sure they include all the 
partitions, and validate that there is no data
+        // remaining in the fetch buffer.
+        assertEquals(partitions, fetchRecords().keySet());
+        assertEquals(0, fetcher.fetchBuffer.bufferedPartitions().size());
+
+        // sendFetches() call #6 should issue a request to nodes 0 and 1 since 
its buffered data was collected.
+        List<NetworkClientDelegate.UnsentRequest> call6 = 
fetcher.sendFetches();
+        assertEquals(2, call6.size());
+        assertEquals(partitions, partitionsRequested(call6));
+        assertEquals(new HashSet<>(nodes), nodesRequested(call6));
+
+        prepareFetchResponses(node0, node0Partitions, 20);
+        prepareFetchResponses(node1, node1Partitions, 20);
+        networkClientDelegate.poll(time.timer(0));
+        assertEquals(4, fetcher.fetchBuffer.bufferedPartitions().size());
+
+        // Just for completeness, collect all the records and make sure they 
include all the partitions, and validate
+        // that there is no data remaining in the fetch buffer.
+        assertEquals(partitions, fetchRecords().keySet());
+        assertEquals(0, fetcher.fetchBuffer.bufferedPartitions().size());
+    }
+
+    @Test
+    public void testFetchRequestWithBufferedPartitionNotAssigned() {
+        buildFetcher();
+
+        // The partitions are spread across multiple nodes to ensure the 
fetcher's logic correctly handles the
+        // partition-to-node mapping.
+        int numNodes = 2;
+        Set<TopicPartition> partitions = Set.of(tp0, tp1, tp2, tp3);
+        assignFromUser(partitions, numNodes);
+
+        // Seek each partition so that it becomes eligible to fetch.
+        partitions.forEach(tp -> subscriptions.seek(tp, 0));
+
+        // Get all the nodes serving as the leader for these partitions.
+        List<Node> nodes = nodesForPartitionLeaders(partitions);
+
+        // Extract the nodes and their respective set of partitions to make 
things easier to keep track of later.
+        assertEquals(2, nodes.size());
+        Node node0 = nodes.get(0);
+        Node node1 = nodes.get(1);
+        List<TopicPartition> node0Partitions = partitionsForNode(node0, 
partitions);
+        List<TopicPartition> node1Partitions = partitionsForNode(node1, 
partitions);
+        assertEquals(2, node0Partitions.size());
+        assertEquals(2, node1Partitions.size());
+        TopicPartition node0Partition1 = node0Partitions.get(0);
+        TopicPartition node0Partition2 = node0Partitions.get(1);
+        TopicPartition node1Partition1 = node1Partitions.get(0);
+        TopicPartition node1Partition2 = node1Partitions.get(1);
+
+        // sendFetches() call #1 should issue requests to node 0 or node 1 
since neither has buffered data.
+        List<NetworkClientDelegate.UnsentRequest> call1 = 
fetcher.sendFetches();
+        assertEquals(2, call1.size());
+        assertEquals(partitions, partitionsRequested(call1));
+        assertEquals(new HashSet<>(nodes), nodesRequested(call1));
+
+        prepareFetchResponses(node0, node0Partitions, 0);
+        prepareFetchResponses(node1, node1Partitions, 0);
+        networkClientDelegate.poll(time.timer(0));
+
+        // Collect node0Partition1 so that it doesn't have anything in the 
fetch buffer.
+        assertEquals(4, fetcher.fetchBuffer.bufferedPartitions().size());
+        collectSelectedPartition(node0Partition1, partitions);
+        assertEquals(3, fetcher.fetchBuffer.bufferedPartitions().size());
+
+        // Exclude node0Partition2 (the remaining buffered partition for node 
0) when updating the assigned partitions
+        // to cause it to become unassigned.
+        subscriptions.assignFromUser(Set.of(
+            node0Partition1,
+            // node0Partition2,         // Intentionally omit this partition 
so that it is unassigned
+            node1Partition1,
+            node1Partition2
+        ));
+
+        // node0Partition1 (the collected partition) should have a retrievable 
position, but node0Partition2
+        // (the unassigned position) should throw an error when attempting to 
retrieve its position.
+        assertDoesNotThrow(() -> subscriptions.position(node0Partition1));
+        assertThrows(IllegalStateException.class, () -> 
subscriptions.position(node0Partition2));
+
+        // sendFetches() call #2 should issue a request to node 0 because the 
first partition in node 0 was collected
+        // (and its buffer removed) and the second partition for node 0 was 
unassigned. As a result, there are now no
+        // *assigned* partitions for node 0 that are buffered.
+        List<NetworkClientDelegate.UnsentRequest> call2 = 
fetcher.sendFetches();
+        assertEquals(1, call2.size());
+        assertEquals(Set.of(node0Partition1), partitionsRequested(call2));
+        assertEquals(Set.of(node0), nodesRequested(call2));
+    }
+
+    @Test
+    public void testFetchRequestWithBufferedPartitionMissingLeader() {
+        buildFetcher();
+
+        Set<TopicPartition> partitions = Set.of(tp0, tp1);
+        assignFromUser(partitions);
+
+        // Seek each partition so that it becomes eligible to fetch.
+        partitions.forEach(tp -> subscriptions.seek(tp, 0));
+
+        Node leader = metadata.fetch().leaderFor(tp0);
+        assertNotNull(leader);
+
+        // sendFetches() call #1 should issue a request since there's no 
buffered data.
+        List<NetworkClientDelegate.UnsentRequest> call1 = 
fetcher.sendFetches();
+        assertEquals(1, call1.size());
+        assertEquals(Set.of(tp0, tp1), partitionsRequested(call1));
+        assertEquals(Set.of(leader), nodesRequested(call1));
+
+        prepareFetchResponses(leader, Set.of(tp0, tp1), 0);
+        networkClientDelegate.poll(time.timer(0));
+
+        // Per the fetch response, data for both of the partitions are in the 
fetch buffer.
+        assertTrue(fetcher.fetchBuffer.bufferedPartitions().contains(tp0));
+        assertTrue(fetcher.fetchBuffer.bufferedPartitions().contains(tp1));
+
+        // Collect the first partition (tp0) which will remove it from the 
fetch buffer.
+        collectSelectedPartition(tp0, partitions);
+
+        // Since tp0 was collected, it's not in the fetch buffer, but tp1 
remains in the fetch buffer.
+        assertFalse(fetcher.fetchBuffer.bufferedPartitions().contains(tp0));
+        assertTrue(fetcher.fetchBuffer.bufferedPartitions().contains(tp1));
+
+        // Overwrite tp1's position with an empty leader, but verify that it 
is still buffered. Having a leaderless,
+        // buffered partition is key to triggering the test case.
+        subscriptions.position(tp1, new SubscriptionState.FetchPosition(
+            0,
+            Optional.empty(),
+            Metadata.LeaderAndEpoch.noLeaderOrEpoch()
+        ));
+        assertTrue(fetcher.fetchBuffer.bufferedPartitions().contains(tp1));
+
+        // Validate the state of the collected partition (tp0) and leaderless 
partition (tp1) before sending the
+        // fetch request.
+        
assertTrue(subscriptions.position(tp0).currentLeader.leader.isPresent());
+        
assertFalse(subscriptions.position(tp1).currentLeader.leader.isPresent());
+
+        // sendFetches() call #2 should issue a fetch request because it has 
no buffered partitions:
+        //
+        // - tp0 was collected and thus not in the fetch buffer
+        // - tp1, while still in the fetch buffer, is leaderless
+        //
+        // As a result, there are now effectively no buffered partitions for 
which there is a leader.
+        List<NetworkClientDelegate.UnsentRequest> call2 = 
fetcher.sendFetches();
+        assertEquals(1, call2.size());
+        assertEquals(Set.of(tp0), partitionsRequested(call2));
+        assertEquals(Set.of(leader), nodesRequested(call2));
+    }
+
+    @Test
+    public void testFetchRequestWithBufferedPartitionMissingPosition() {
+        buildFetcher();
+
+        // The partitions are spread across multiple nodes to ensure the 
fetcher's logic correctly handles the
+        // partition-to-node mapping.
+        int numNodes = 2;
+        Set<TopicPartition> partitions = Set.of(tp0, tp1, tp2, tp3);
+        assignFromUser(partitions, numNodes);
+
+        // Seek each partition so that it becomes eligible to fetch.
+        partitions.forEach(tp -> subscriptions.seek(tp, 0));
+
+        // Get all the nodes serving as the leader for these partitions.
+        List<Node> nodes = nodesForPartitionLeaders(partitions);
+
+        // Extract the nodes and their respective set of partitions to make 
things easier to keep track of later.
+        assertEquals(2, nodes.size());
+        Node node0 = nodes.get(0);
+        Node node1 = nodes.get(1);
+        List<TopicPartition> node0Partitions = partitionsForNode(node0, 
partitions);
+        List<TopicPartition> node1Partitions = partitionsForNode(node1, 
partitions);
+        assertEquals(2, node0Partitions.size());
+        assertEquals(2, node1Partitions.size());
+        TopicPartition node0Partition1 = node0Partitions.get(0);
+        TopicPartition node0Partition2 = node0Partitions.get(1);
+
+        // sendFetches() call #1 should issue requests to node 0 or node 1 
since neither has buffered data.
+        List<NetworkClientDelegate.UnsentRequest> call1 = 
fetcher.sendFetches();
+        assertEquals(2, call1.size());
+        assertEquals(partitions, partitionsRequested(call1));
+        assertEquals(new HashSet<>(nodes), nodesRequested(call1));
+
+        prepareFetchResponses(node0, node0Partitions, 0);
+        prepareFetchResponses(node1, node1Partitions, 0);
+        networkClientDelegate.poll(time.timer(0));
+
+        // Collect node 0's first partition (node0Partition1) which will 
remove it from the fetch buffer.
+        assertEquals(4, fetcher.fetchBuffer.bufferedPartitions().size());
+        collectSelectedPartition(node0Partition1, partitions);
+        assertEquals(3, fetcher.fetchBuffer.bufferedPartitions().size());
+
+        // Overwrite node0Partition2's position with an empty leader to 
trigger the test case.
+        subscriptions.position(node0Partition2, null);
+
+        // Confirm that calling SubscriptionState.position() succeeds for a 
leaderless partition. While it shouldn't
+        // throw an exception, it should return a null position.
+        SubscriptionState.FetchPosition position = assertDoesNotThrow(() -> 
subscriptions.position(node0Partition2));
+        assertNull(position);
+
+        // sendFetches() call #2 will now fail to send any requests as we have 
an invalid position in the assignment.
+        // The Consumer.poll() API will throw an IllegalStateException to the 
user.
+        Future<Void> future = fetcher.createFetchRequests();
+        List<NetworkClientDelegate.UnsentRequest> call2 = 
fetcher.sendFetches();
+        assertEquals(0, call2.size());
+        assertFutureThrows(future, IllegalStateException.class);
+    }
+
+    @Test
+    public void testFetchRequestWithBufferedPartitionPaused() {
+        testFetchRequestWithBufferedPartitionUnfetchable(tp -> 
subscriptions.pause(tp));
+    }
+
+    @Test
+    public void testFetchRequestWithBufferedPartitionPendingRevocation() {
+        testFetchRequestWithBufferedPartitionUnfetchable(tp -> 
subscriptions.markPendingRevocation(Set.of(tp)));
+    }
+
+    @Test
+    public void testFetchRequestWithBufferedPartitionPendingAssignment() {
+        testFetchRequestWithBufferedPartitionUnfetchable(tp -> 
subscriptions.markPendingOnAssignedCallback(Set.of(tp), true));
+    }
+
+    @Test
+    public void testFetchRequestWithBufferedPartitionResetOffset() {
+        testFetchRequestWithBufferedPartitionUnfetchable(tp -> 
subscriptions.requestOffsetReset(tp));
+    }
+
+    private void 
testFetchRequestWithBufferedPartitionUnfetchable(java.util.function.Consumer<TopicPartition>
 partitionMutator) {
+        buildFetcher();
+
+        Set<TopicPartition> partitions = Set.of(tp0, tp1);
+        assignFromUser(partitions);
+
+        // Seek each partition so that it becomes eligible to fetch.
+        partitions.forEach(tp -> subscriptions.seek(tp, 0));
+
+        // sendFetches() call #1 should issue a request since there's no 
buffered data.
+        List<NetworkClientDelegate.UnsentRequest> call1 = 
fetcher.sendFetches();
+        assertEquals(1, call1.size());
+        assertEquals(Set.of(tp0, tp1), partitionsRequested(call1));
+        prepareFetchResponses(metadata.fetch().leaderFor(tp0), Set.of(tp0, 
tp1), 0);
+        networkClientDelegate.poll(time.timer(0));
+
+        // Per the fetch response, data for both of the partitions are in the 
fetch buffer.
+        assertTrue(fetcher.fetchBuffer.bufferedPartitions().contains(tp0));
+        assertTrue(fetcher.fetchBuffer.bufferedPartitions().contains(tp1));
+
+        // Collect the first partition (tp0) which will remove it from the 
fetch buffer.
+        collectSelectedPartition(tp0, partitions);
+        assertFalse(fetcher.fetchBuffer.bufferedPartitions().contains(tp0));
+
+        // Mutate tp1 to make it unfetchable, but verify that it is still 
buffered. Having a buffered partition that
+        // is also unfetchable is key to triggering the test case.
+        partitionMutator.accept(tp1);
+        assertTrue(fetcher.fetchBuffer.bufferedPartitions().contains(tp1));
+
+        // sendFetches() call #2 should issue a fetch request because it has 
no buffered partitions:
+        //
+        // - tp0 was collected and thus not in the fetch buffer
+        // - tp1, while still in the fetch buffer, is unfetchable and should 
be ignored
+        List<NetworkClientDelegate.UnsentRequest> call2 = 
fetcher.sendFetches();
+        assertEquals(1, call2.size());
+        assertEquals(Set.of(tp0), partitionsRequested(call2));
+    }
+
+    /**
+     * For each partition given, return the set of nodes that represent the 
partition's leader using
+     * {@link Cluster#leaderFor(TopicPartition)}.
+     */
+    private List<Node> nodesForPartitionLeaders(Set<TopicPartition> 
partitions) {
+        Cluster cluster = metadata.fetch();
+
+        return partitions.stream()
+            .map(cluster::leaderFor)
+            .filter(Objects::nonNull)
+            .distinct()
+            .collect(Collectors.toList());
+    }
+
+    /**
+     * For the given set of partitions, filter the partitions to be those 
where the partition's leader node
+     * (via {@link Cluster#leaderFor(TopicPartition)}) matches the given node.
+     */
+    private List<TopicPartition> partitionsForNode(Node node, 
Set<TopicPartition> partitions) {
+        Cluster cluster = metadata.fetch();
+
+        return partitions.stream()
+            .filter(tp -> node.equals(cluster.leaderFor(tp)))
+            .collect(Collectors.toList());
+    }
+
+    /**
+     * Creates 10 dummy records starting at the given offset for each given 
partition and directs each response to the
+     * given node.
+     */
+    private void prepareFetchResponses(Node node, Collection<TopicPartition> 
partitions, int offset) {
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> 
partitionDataMap = new LinkedHashMap<>();
+
+        partitions.forEach(tp -> {
+            MemoryRecords records = buildRecords(offset, 10, 1);
+            FetchResponseData.PartitionData partitionData = new 
FetchResponseData.PartitionData()
+                .setPartitionIndex(tp.partition())
+                .setHighWatermark(100)
+                .setRecords(records);
+            partitionDataMap.put(new TopicIdPartition(topicId, tp), 
partitionData);
+        });
+
+        client.prepareResponseFrom(
+            FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, 
partitionDataMap),
+            node
+        );
+    }
+
+    /**
+     * Invokes {@link #collectFetch()}, but before doing so it {@link 
Consumer#pause(Collection) pauses} all the
+     * partitions in the given set of partitions <em>except</em> for {@code 
partition}. This is done so that only
+     * that partition will be collected. Once the collection has been 
performed, the previously-paused partitions
+     * are then {@link Consumer#resume(Collection) resumed}.
+     */
+    private void collectSelectedPartition(TopicPartition partition, 
Set<TopicPartition> partitions) {
+        // Pause any remaining partitions so that when fetchRecords() is 
called, only the records for the
+        // "fetched" partition are collected, leaving the remaining in the 
fetch buffer.
+        Set<TopicPartition> pausedPartitions = partitions.stream()
+            .filter(tp -> !tp.equals(partition))
+            .collect(Collectors.toSet());
+
+        // Fetch the records, which should be just for the expected topic 
partition since the others were paused.
+        pausedPartitions.forEach(tp -> subscriptions.pause(tp));
+        fetchRecords();
+        pausedPartitions.forEach(tp -> subscriptions.resume(tp));
+    }
+
+    /**
+     * Returns the unique set of partitions that were included in the given 
requests.
+     */
+    private Set<TopicPartition> 
partitionsRequested(List<NetworkClientDelegate.UnsentRequest> requests) {
+        return requests.stream()
+            .map(NetworkClientDelegate.UnsentRequest::requestBuilder)
+            .filter(FetchRequest.Builder.class::isInstance)
+            .map(FetchRequest.Builder.class::cast)
+            .map(FetchRequest.Builder::fetchData)
+            .map(Map::keySet)
+            .flatMap(Set::stream)
+            .collect(Collectors.toSet());
+    }
+
+    /**
+     * Returns the unique set of nodes to which fetch requests were sent.
+     */
+    private Set<Node> nodesRequested(List<NetworkClientDelegate.UnsentRequest> 
requests) {
+        return requests.stream()
+            .map(NetworkClientDelegate.UnsentRequest::node)
+            .filter(Optional::isPresent)
+            .map(Optional::get)
+            .collect(Collectors.toSet());
     }
 
     private FetchResponse fetchResponseWithTopLevelError(TopicIdPartition tp, 
Errors error, int throttleTime) {
@@ -3742,6 +4188,14 @@ public class FetchRequestManagerTest {
             return pollResult.unsentRequests.size();
         }
 
+        private List<NetworkClientDelegate.UnsentRequest> sendFetches() {
+            offsetFetcher.validatePositionsOnMetadataChange();
+            createFetchRequests();
+            NetworkClientDelegate.PollResult pollResult = 
poll(time.milliseconds());
+            networkClientDelegate.addAll(pollResult.unsentRequests);
+            return pollResult.unsentRequests;
+        }
+
         private void 
clearBufferedDataForUnassignedPartitions(Set<TopicPartition> partitions) {
             fetchBuffer.retainAll(partitions);
         }
@@ -3749,7 +4203,7 @@ public class FetchRequestManagerTest {
 
     private class TestableNetworkClientDelegate extends NetworkClientDelegate {
 
-        private final Logger log = 
LoggerFactory.getLogger(NetworkClientDelegate.class);
+        private final Logger log = 
LoggerFactory.getLogger(TestableNetworkClientDelegate.class);
         private final ConcurrentLinkedQueue<Node> pendingDisconnects = new 
ConcurrentLinkedQueue<>();
 
         public TestableNetworkClientDelegate(Time time,
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
index 856c2b9478a..b24475300c8 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
@@ -209,11 +209,15 @@ public class FetcherTest {
     }
 
     private void assignFromUser(Set<TopicPartition> partitions) {
+        assignFromUser(partitions, 1);
+    }
+
+    private void assignFromUser(Set<TopicPartition> partitions, int numNodes) {
         subscriptions.assignFromUser(partitions);
-        client.updateMetadata(initialUpdateResponse);
+        client.updateMetadata(RequestTestUtils.metadataUpdateWithIds(numNodes, 
singletonMap(topicName, 4), topicIds));
 
         // A dummy metadata update to ensure valid leader epoch.
-        
metadata.updateWithCurrentRequestVersion(RequestTestUtils.metadataUpdateWithIds("dummy",
 1,
+        
metadata.updateWithCurrentRequestVersion(RequestTestUtils.metadataUpdateWithIds("dummy",
 numNodes,
             Collections.emptyMap(), singletonMap(topicName, 4),
             tp -> validLeaderEpoch, topicIds), false, 0L);
     }
@@ -1152,16 +1156,17 @@ public class FetcherTest {
         Set<TopicPartition> tps = new HashSet<>();
         tps.add(tp0);
         tps.add(tp1);
-        assignFromUser(tps);
+        assignFromUser(tps, 2);                 // Use multiple nodes so 
partitions have different leaders
         subscriptions.seek(tp0, 1);
         subscriptions.seek(tp1, 6);
 
-        client.prepareResponse(fetchResponse2(tidp0, records, 100L, tidp1, 
moreRecords, 100L));
+        client.prepareResponse(fullFetchResponse(tidp0, records, Errors.NONE, 
100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp1, moreRecords, 
Errors.NONE, 100L, 0));
         client.prepareResponse(fullFetchResponse(tidp0, emptyRecords, 
Errors.NONE, 100L, 0));
 
-        // Send fetch request because we do not have pending fetch responses 
to process.
-        // The first fetch response will return 3 records for tp0 and 3 more 
for tp1.
-        assertEquals(1, sendFetches());
+        // Send two fetch requests (one to each node) because we do not have 
pending fetch responses to process.
+        // The fetch responses will return 3 records for tp0 and 3 more for 
tp1.
+        assertEquals(2, sendFetches());
         // The poll returns 2 records from one of the topic-partitions 
(non-deterministic).
         // This leaves 1 record pending from that topic-partition, and the 
remaining 3 from the other.
         pollAndValidateMaxPollRecordsNotExceeded(maxPollRecords);
@@ -1428,7 +1433,7 @@ public class FetcherTest {
 
         Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> 
fetchedRecords;
 
-        assignFromUser(Set.of(tp0, tp1));
+        assignFromUser(Set.of(tp0, tp1), 2);    // Use multiple nodes so 
partitions have different leaders
 
         // seek to tp0 and tp1 in two polls to generate 2 complete requests 
and responses
 
@@ -1460,7 +1465,7 @@ public class FetcherTest {
     public void testFetchOnCompletedFetchesForAllPausedPartitions() {
         buildFetcher();
 
-        assignFromUser(Set.of(tp0, tp1));
+        assignFromUser(Set.of(tp0, tp1), 2);    // Use multiple nodes so 
partitions have different leaders
 
         // seek to tp0 and tp1 in two polls to generate 2 complete requests 
and responses
 
@@ -1823,7 +1828,9 @@ public class FetcherTest {
         buildFetcher(AutoOffsetResetStrategy.NONE, new ByteArrayDeserializer(),
                 new ByteArrayDeserializer(), 2, 
IsolationLevel.READ_UNCOMMITTED);
 
-        assignFromUser(Set.of(tp0));
+        // Use multiple nodes so partitions have different leaders. tp0 is 
added here, but tp1 is also assigned
+        // about halfway down.
+        assignFromUser(Set.of(tp0), 2);
         subscriptions.seek(tp0, 1);
         assertEquals(1, sendFetches());
         Map<TopicIdPartition, FetchResponseData.PartitionData> partitions = 
new HashMap<>();
@@ -3757,28 +3764,6 @@ public class FetcherTest {
         return FetchResponse.of(Errors.NONE, throttleTime, INVALID_SESSION_ID, 
new LinkedHashMap<>(partitions));
     }
 
-    private FetchResponse fetchResponse2(TopicIdPartition tp1, MemoryRecords 
records1, long hw1,
-                                         TopicIdPartition tp2, MemoryRecords 
records2, long hw2) {
-        Map<TopicIdPartition, FetchResponseData.PartitionData> partitions = 
new HashMap<>();
-        partitions.put(tp1,
-                new FetchResponseData.PartitionData()
-                        .setPartitionIndex(tp1.topicPartition().partition())
-                        .setErrorCode(Errors.NONE.code())
-                        .setHighWatermark(hw1)
-                        
.setLastStableOffset(FetchResponse.INVALID_LAST_STABLE_OFFSET)
-                        .setLogStartOffset(0)
-                        .setRecords(records1));
-        partitions.put(tp2,
-                new FetchResponseData.PartitionData()
-                        .setPartitionIndex(tp2.topicPartition().partition())
-                        .setErrorCode(Errors.NONE.code())
-                        .setHighWatermark(hw2)
-                        
.setLastStableOffset(FetchResponse.INVALID_LAST_STABLE_OFFSET)
-                        .setLogStartOffset(0)
-                        .setRecords(records2));
-        return FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, new 
LinkedHashMap<>(partitions));
-    }
-
     /**
      * Assert that the {@link Fetcher#collectFetch() latest fetch} does not 
contain any
      * {@link Fetch#records() user-visible records}, did not

