This is an automated email from the ASF dual-hosted git repository.

schofielaj pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 8209af2a237 KAFKA-20064: Made PartitionLeaderCache thread safe. 
(#21335)
8209af2a237 is described below

commit 8209af2a23701c920e0eb9fe7933f6054d90edf0
Author: Nikita Shupletsov <[email protected]>
AuthorDate: Thu Jan 29 12:06:49 2026 -0800

    KAFKA-20064: Made PartitionLeaderCache thread safe. (#21335)
    
    * Introduced a new class - PartitionLeaderCache
    * Changed the usage of the cache to make calls atomic (e.g. getting
    cached and non-cached values in one call instead of two, and deleting
    cached values in one call rather than one by one)
    * Added an integration test that exercises concurrent access to the
    cache
    
    Reviewers: Andrew Schofield <[email protected]>
---
 .../admin/ConcurrentListOffsetsRequestTest.java    | 199 +++++++++++++++++++++
 .../kafka/clients/admin/KafkaAdminClient.java      |   5 +-
 .../admin/internals/AbortTransactionHandler.java   |   3 +-
 .../clients/admin/internals/AdminApiDriver.java    |   9 +-
 .../clients/admin/internals/AdminApiFuture.java    |  23 ++-
 .../admin/internals/DeleteRecordsHandler.java      |   2 +-
 .../admin/internals/DescribeProducersHandler.java  |   2 +-
 .../admin/internals/ListOffsetsHandler.java        |   2 +-
 .../admin/internals/PartitionLeaderCache.java      |  54 ++++++
 .../admin/internals/PartitionLeaderStrategy.java   |  29 +--
 .../PartitionLeaderStrategyIntegrationTest.java    |  46 ++---
 11 files changed, 310 insertions(+), 64 deletions(-)

diff --git 
a/clients/clients-integration-tests/src/test/java/org/apache/kafka/clients/admin/ConcurrentListOffsetsRequestTest.java
 
b/clients/clients-integration-tests/src/test/java/org/apache/kafka/clients/admin/ConcurrentListOffsetsRequestTest.java
new file mode 100644
index 00000000000..8facea1de34
--- /dev/null
+++ 
b/clients/clients-integration-tests/src/test/java/org/apache/kafka/clients/admin/ConcurrentListOffsetsRequestTest.java
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.clients.admin;
+
+import org.apache.kafka.clients.CommonClientConfigs;
+import org.apache.kafka.clients.DefaultHostResolver;
+import org.apache.kafka.clients.NetworkClient;
+import org.apache.kafka.clients.admin.internals.PartitionLeaderCache;
+import org.apache.kafka.common.IsolationLevel;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.errors.TimeoutException;
+import org.apache.kafka.common.test.ClusterInstance;
+import org.apache.kafka.common.test.api.ClusterTest;
+import org.apache.kafka.common.test.api.ClusterTestDefaults;
+import org.apache.kafka.common.test.api.Type;
+import org.apache.kafka.common.utils.Utils;
+import org.apache.kafka.test.TestUtils;
+
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+
+import java.net.InetAddress;
+import java.net.UnknownHostException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+import static org.junit.jupiter.api.Assertions.assertInstanceOf;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+@ClusterTestDefaults(
+        types = {Type.KRAFT},
+        brokers = 3
+)
+public class ConcurrentListOffsetsRequestTest {
+    private static final String TOPIC = "topic";
+    private static final short REPLICAS = 1;
+    private static final int PARTITION = 2;
+    private static final int TIMEOUT = 1000;
+    private final ClusterInstance clusterInstance;
+    private Admin adminClient;
+    private NetworkClient networkClient;
+    private final AtomicBoolean injectHostResolverError = new 
AtomicBoolean(false);
+
+    ConcurrentListOffsetsRequestTest(ClusterInstance clusterInstance) {
+        this.clusterInstance = clusterInstance;
+    }
+
+    @BeforeEach
+    public void setup() throws Exception {
+        clusterInstance.waitForReadyBrokers();
+        clusterInstance.createTopic(TOPIC, PARTITION, REPLICAS);
+        Map<String, Object> props = Map.of(
+                "default.api.timeout.ms", TIMEOUT,
+                "request.timeout.ms", TIMEOUT,
+                CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, 
clusterInstance.bootstrapServers());
+        adminClient = KafkaAdminClient.createInternal(new 
AdminClientConfig(clusterInstance.setClientSaslConfig(props), true),
+                null, new TestHostResolver());
+
+        networkClient = TestUtils.fieldValue(adminClient, 
KafkaAdminClient.class, "client");
+    }
+
+    @AfterEach
+    public void teardown() {
+        Utils.closeQuietly(adminClient, "ListOffsetsAdminClient");
+    }
+
+    @ClusterTest
+    public void correctlyHandleConcurrentModificationOfPartitionLeaderCache() 
throws Exception {
+        // making one request to prepopulate the partition leader cache so we 
have something to delete later
+        listAllOffsets().all().get(TIMEOUT * 2, TimeUnit.SECONDS);
+
+        final CountDownLatch invalidationLatch = new CountDownLatch(1);
+        // Replacing the partition leader cache in order to be able to 
synchronize the calls so that they happen in the right order to reproduce the 
issue
+        TestPartitionLeaderCache testPartitionLeaderCache = 
replacePartitionLeaderCache(invalidationLatch);
+
+        // closing the connection to the first node. not using 
clusterInstance.shutdownBroker to reduce flakiness
+        
networkClient.close(testPartitionLeaderCache.get(getTopicPartitions()).values().iterator().next().toString());
+        // as the next call will try to resolve the host for the closed node, it's 
time to let it fail, which will lead to cache invalidation
+        injectHostResolverError.set(true);
+
+        // making another request (this request will face the host resolver 
error and remove the node from the cache)
+        ListOffsetsResult failInducingResult = listAllOffsets();
+        // waiting until we get to the invalidation
+        invalidationLatch.await();
+        // making another request. at this point the fail inducing request is 
waiting for this one before it deletes the keys associated with the node
+        // the TestPartitionLeaderCache class synchronizes the calls to mimic 
the race condition
+        ListOffsetsResult failingResult = listAllOffsets();
+
+        // verifying that we correctly declined the call
+        ExecutionException executionException = 
assertThrows(ExecutionException.class, () -> 
failInducingResult.all().get(TIMEOUT * 2, TimeUnit.MILLISECONDS));
+        assertInstanceOf(TimeoutException.class, 
executionException.getCause());
+
+        // verifying that the second (concurrent) request was declined as well
+        executionException = assertThrows(ExecutionException.class, () -> 
failingResult.all().get(TIMEOUT * 2, TimeUnit.MILLISECONDS));
+        assertInstanceOf(TimeoutException.class, 
executionException.getCause());
+    }
+
+    private TestPartitionLeaderCache 
replacePartitionLeaderCache(CountDownLatch invalidationLatch) throws Exception {
+        PartitionLeaderCache oldPartitionLeaderCache = 
TestUtils.fieldValue(adminClient, KafkaAdminClient.class, 
"partitionLeaderCache");
+
+        TestPartitionLeaderCache partitionLeaderCache = new 
TestPartitionLeaderCache(oldPartitionLeaderCache.get(getTopicPartitions()), 
invalidationLatch);
+        TestUtils.setFieldValue(adminClient, "partitionLeaderCache", 
partitionLeaderCache);
+        return partitionLeaderCache;
+    }
+
+    private ListOffsetsResult listAllOffsets() {
+        List<TopicPartition> partitions = getTopicPartitions();
+
+        Map<TopicPartition, OffsetSpec> offsetSpecMap = 
partitions.stream().collect(Collectors.toMap(Function.identity(), tp -> 
OffsetSpec.latest()));
+        return adminClient.listOffsets(offsetSpecMap, new 
ListOffsetsOptions(IsolationLevel.READ_UNCOMMITTED));
+    }
+
+    private List<TopicPartition> getTopicPartitions() {
+        List<TopicPartition> partitions = new ArrayList<>();
+        for (int i = 0; i < PARTITION; i++) {
+            partitions.add(new TopicPartition(TOPIC, i));
+        }
+        return partitions;
+    }
+
+    private static class TestPartitionLeaderCache extends PartitionLeaderCache 
{
+
+        private final AtomicInteger getCounter = new AtomicInteger(0);
+        private final CountDownLatch invalidationLatch;
+        private final CountDownLatch newRequestCheckLatch = new 
CountDownLatch(1);
+        private final CountDownLatch removeCompleteLatch = new 
CountDownLatch(1);
+
+        public TestPartitionLeaderCache(Map<TopicPartition, Integer> 
oldPartitionLeaderCache, final CountDownLatch invalidationLatch) {
+            put(oldPartitionLeaderCache);
+            this.invalidationLatch = invalidationLatch;
+        }
+
+        @Override
+        public Map<TopicPartition, Integer> get(Collection<TopicPartition> 
keys) {
+            Map<TopicPartition, Integer> result = super.get(keys);
+            // waiting for the third call: first one was to close the network 
connection, second one was from the request that invalidates the cache
+            if (getCounter.incrementAndGet() == 3) {
+                newRequestCheckLatch.countDown();
+                try {
+                    // letting the remove method proceed and actually remove 
the data
+                    removeCompleteLatch.await();
+                } catch (InterruptedException e) {
+                    throw new RuntimeException(e);
+                }
+            }
+
+            return result;
+        }
+
+        @Override
+        public void remove(Collection<TopicPartition> keys) {
+            try {
+                // letting the caller know that we've reached the invalidation 
step, and it's time to send the second request
+                invalidationLatch.countDown();
+                // waiting for the second request to reach get
+                newRequestCheckLatch.await();
+            } catch (InterruptedException e) {
+                throw new RuntimeException(e);
+            }
+            super.remove(keys);
+            // once the value is removed, we let the get method proceed 
and return the value it already read
+            removeCompleteLatch.countDown();
+        }
+    }
+
+    private class TestHostResolver extends DefaultHostResolver {
+
+        @Override
+        public InetAddress[] resolve(String host) throws UnknownHostException {
+            if (injectHostResolverError.get()) {
+                throw new UnknownHostException();
+            }
+            return super.resolve(host);
+        }
+    }
+}
diff --git 
a/clients/src/main/java/org/apache/kafka/clients/admin/KafkaAdminClient.java 
b/clients/src/main/java/org/apache/kafka/clients/admin/KafkaAdminClient.java
index fa9440dfb91..a80c860de5e 100644
--- a/clients/src/main/java/org/apache/kafka/clients/admin/KafkaAdminClient.java
+++ b/clients/src/main/java/org/apache/kafka/clients/admin/KafkaAdminClient.java
@@ -63,6 +63,7 @@ import 
org.apache.kafka.clients.admin.internals.ListConsumerGroupOffsetsHandler;
 import org.apache.kafka.clients.admin.internals.ListOffsetsHandler;
 import org.apache.kafka.clients.admin.internals.ListShareGroupOffsetsHandler;
 import org.apache.kafka.clients.admin.internals.ListTransactionsHandler;
+import org.apache.kafka.clients.admin.internals.PartitionLeaderCache;
 import org.apache.kafka.clients.admin.internals.PartitionLeaderStrategy;
 import 
org.apache.kafka.clients.admin.internals.RemoveMembersFromConsumerGroupHandler;
 import org.apache.kafka.clients.consumer.OffsetAndMetadata;
@@ -407,7 +408,7 @@ public class KafkaAdminClient extends AdminClient {
     private final long retryBackoffMaxMs;
     private final ExponentialBackoff retryBackoff;
     private final MetadataRecoveryStrategy metadataRecoveryStrategy;
-    private final Map<TopicPartition, Integer> partitionLeaderCache;
+    private final PartitionLeaderCache partitionLeaderCache;
     private final AdminFetchMetricsManager adminFetchMetricsManager;
     private final Optional<ClientTelemetryReporter> clientTelemetryReporter;
 
@@ -631,7 +632,7 @@ public class KafkaAdminClient extends AdminClient {
             CommonClientConfigs.RETRY_BACKOFF_JITTER);
         this.clientTelemetryReporter = clientTelemetryReporter;
         this.metadataRecoveryStrategy = 
MetadataRecoveryStrategy.forName(config.getString(AdminClientConfig.METADATA_RECOVERY_STRATEGY_CONFIG));
-        this.partitionLeaderCache = new HashMap<>();
+        this.partitionLeaderCache = new PartitionLeaderCache();
         this.adminFetchMetricsManager = new AdminFetchMetricsManager(metrics);
         config.logUnused();
         AppInfoParser.registerAppInfo(JMX_PREFIX, clientId, metrics, 
time.milliseconds());
diff --git 
a/clients/src/main/java/org/apache/kafka/clients/admin/internals/AbortTransactionHandler.java
 
b/clients/src/main/java/org/apache/kafka/clients/admin/internals/AbortTransactionHandler.java
index f0b6d28be6b..80ad8da3e3e 100644
--- 
a/clients/src/main/java/org/apache/kafka/clients/admin/internals/AbortTransactionHandler.java
+++ 
b/clients/src/main/java/org/apache/kafka/clients/admin/internals/AbortTransactionHandler.java
@@ -34,7 +34,6 @@ import org.apache.kafka.common.utils.LogContext;
 import org.slf4j.Logger;
 
 import java.util.List;
-import java.util.Map;
 import java.util.Set;
 
 import static java.util.Collections.singleton;
@@ -56,7 +55,7 @@ public class AbortTransactionHandler extends 
AdminApiHandler.Batched<TopicPartit
 
     public static PartitionLeaderStrategy.PartitionLeaderFuture<Void> 
newFuture(
         Set<TopicPartition> topicPartitions,
-        Map<TopicPartition, Integer> partitionLeaderCache
+        PartitionLeaderCache partitionLeaderCache
     ) {
         return new 
PartitionLeaderStrategy.PartitionLeaderFuture<>(topicPartitions, 
partitionLeaderCache);
     }
diff --git 
a/clients/src/main/java/org/apache/kafka/clients/admin/internals/AdminApiDriver.java
 
b/clients/src/main/java/org/apache/kafka/clients/admin/internals/AdminApiDriver.java
index 6286f59ed71..2db63c7ed57 100644
--- 
a/clients/src/main/java/org/apache/kafka/clients/admin/internals/AdminApiDriver.java
+++ 
b/clients/src/main/java/org/apache/kafka/clients/admin/internals/AdminApiDriver.java
@@ -115,8 +115,13 @@ public class AdminApiDriver<K, V> {
         // metadata. For all cached keys, they can proceed straight to the 
fulfillment map.
         // Note that the cache is only used on the initial calls, and any 
errors that result
         // in additional lookups use the full set of lookup keys.
-        retryLookup(future.uncachedLookupKeys());
-        future.cachedKeyBrokerIdMapping().forEach((key, brokerId) -> 
fulfillmentMap.put(new FulfillmentScope(brokerId), key));
+        future.cachedKeyBrokerIdMapping().forEach((key, brokerId) -> {
+            if (AdminApiFuture.UNKNOWN_BROKER_ID.equals(brokerId)) {
+                unmap(key);
+            } else {
+                fulfillmentMap.put(new FulfillmentScope(brokerId), key);
+            }
+        });
     }
 
     /**
diff --git 
a/clients/src/main/java/org/apache/kafka/clients/admin/internals/AdminApiFuture.java
 
b/clients/src/main/java/org/apache/kafka/clients/admin/internals/AdminApiFuture.java
index 322d116a3df..ed0b60b4304 100644
--- 
a/clients/src/main/java/org/apache/kafka/clients/admin/internals/AdminApiFuture.java
+++ 
b/clients/src/main/java/org/apache/kafka/clients/admin/internals/AdminApiFuture.java
@@ -19,7 +19,7 @@ package org.apache.kafka.clients.admin.internals;
 import org.apache.kafka.common.KafkaFuture;
 import org.apache.kafka.common.internals.KafkaFutureImpl;
 
-import java.util.Collections;
+import java.util.HashMap;
 import java.util.Map;
 import java.util.Set;
 import java.util.function.Function;
@@ -27,6 +27,8 @@ import java.util.stream.Collectors;
 
 public interface AdminApiFuture<K, V> {
 
+    Integer UNKNOWN_BROKER_ID = -1;
+
     /**
      * The initial set of lookup keys. Although this will usually match the 
fulfillment
      * keys, it does not necessarily have to. For example, in the case of
@@ -39,22 +41,17 @@ public interface AdminApiFuture<K, V> {
     Set<K> lookupKeys();
 
     /**
-     * The set of request keys that do not have cached key-broker id mappings. 
If there
-     * is no cached key mapping, this will be the same as the lookup keys.
-     * Can be empty, but only if the cached key mapping is not empty.
-     */
-    default Set<K> uncachedLookupKeys() {
-        return lookupKeys();
-    }
-
-    /**
-     * The cached key-broker id mapping. For lookup strategies that do not 
make use of a
-     * cache of metadata, this will be empty.
+     * The cached key-broker id mapping. For non-cached values (or lookup 
strategies that do not make use of a
+     * cache of metadata) the broker id will be {@link #UNKNOWN_BROKER_ID}.
      *
      * @return mapping of keys to broker ids
      */
     default Map<K, Integer> cachedKeyBrokerIdMapping() {
-        return Collections.emptyMap();
+        Map<K, Integer> result = new HashMap<>();
+        for (K key : lookupKeys()) {
+            result.put(key, UNKNOWN_BROKER_ID);
+        }
+        return result;
     }
 
     /**
diff --git 
a/clients/src/main/java/org/apache/kafka/clients/admin/internals/DeleteRecordsHandler.java
 
b/clients/src/main/java/org/apache/kafka/clients/admin/internals/DeleteRecordsHandler.java
index 4afef617cb2..7a8aca79b39 100644
--- 
a/clients/src/main/java/org/apache/kafka/clients/admin/internals/DeleteRecordsHandler.java
+++ 
b/clients/src/main/java/org/apache/kafka/clients/admin/internals/DeleteRecordsHandler.java
@@ -73,7 +73,7 @@ public final class DeleteRecordsHandler extends 
Batched<TopicPartition, DeletedR
 
     public static 
PartitionLeaderStrategy.PartitionLeaderFuture<DeletedRecords> newFuture(
             Collection<TopicPartition> topicPartitions,
-            Map<TopicPartition, Integer> partitionLeaderCache
+            PartitionLeaderCache partitionLeaderCache
     ) {
         return new PartitionLeaderStrategy.PartitionLeaderFuture<>(new 
HashSet<>(topicPartitions), partitionLeaderCache);
     }
diff --git 
a/clients/src/main/java/org/apache/kafka/clients/admin/internals/DescribeProducersHandler.java
 
b/clients/src/main/java/org/apache/kafka/clients/admin/internals/DescribeProducersHandler.java
index 84338feb9e4..3ae5638423c 100644
--- 
a/clients/src/main/java/org/apache/kafka/clients/admin/internals/DescribeProducersHandler.java
+++ 
b/clients/src/main/java/org/apache/kafka/clients/admin/internals/DescribeProducersHandler.java
@@ -68,7 +68,7 @@ public class DescribeProducersHandler extends 
AdminApiHandler.Batched<TopicParti
 
     public static 
PartitionLeaderStrategy.PartitionLeaderFuture<PartitionProducerState> newFuture(
         Collection<TopicPartition> topicPartitions,
-        Map<TopicPartition, Integer> partitionLeaderCache
+        PartitionLeaderCache partitionLeaderCache
     ) {
         return new PartitionLeaderStrategy.PartitionLeaderFuture<>(new 
HashSet<>(topicPartitions), partitionLeaderCache);
     }
diff --git 
a/clients/src/main/java/org/apache/kafka/clients/admin/internals/ListOffsetsHandler.java
 
b/clients/src/main/java/org/apache/kafka/clients/admin/internals/ListOffsetsHandler.java
index 330a9efaf9b..c03a6c5bee0 100644
--- 
a/clients/src/main/java/org/apache/kafka/clients/admin/internals/ListOffsetsHandler.java
+++ 
b/clients/src/main/java/org/apache/kafka/clients/admin/internals/ListOffsetsHandler.java
@@ -223,7 +223,7 @@ public final class ListOffsetsHandler extends 
Batched<TopicPartition, ListOffset
 
     public static 
PartitionLeaderStrategy.PartitionLeaderFuture<ListOffsetsResultInfo> newFuture(
         Collection<TopicPartition> topicPartitions,
-        Map<TopicPartition, Integer> partitionLeaderCache
+        PartitionLeaderCache partitionLeaderCache
     ) {
         return new PartitionLeaderStrategy.PartitionLeaderFuture<>(new 
HashSet<>(topicPartitions), partitionLeaderCache);
     }
diff --git 
a/clients/src/main/java/org/apache/kafka/clients/admin/internals/PartitionLeaderCache.java
 
b/clients/src/main/java/org/apache/kafka/clients/admin/internals/PartitionLeaderCache.java
new file mode 100644
index 00000000000..089126dce46
--- /dev/null
+++ 
b/clients/src/main/java/org/apache/kafka/clients/admin/internals/PartitionLeaderCache.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.clients.admin.internals;
+
+import org.apache.kafka.common.TopicPartition;
+
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+
+public class PartitionLeaderCache {
+
+    private final Map<TopicPartition, Integer> cache = new HashMap<>();
+
+    public Map<TopicPartition, Integer> get(Collection<TopicPartition> keys) {
+        Map<TopicPartition, Integer> result = new HashMap<>();
+        synchronized (cache) {
+            for (TopicPartition key : keys) {
+                if (cache.containsKey(key)) {
+                    result.put(key, cache.get(key));
+                }
+            }
+        }
+        return result;
+    }
+
+    public void put(Map<TopicPartition, Integer> values) {
+        synchronized (cache) {
+            cache.putAll(values);
+        }
+    }
+
+    public void remove(Collection<TopicPartition> keys) {
+        synchronized (cache) {
+            for (TopicPartition key : keys) {
+                cache.remove(key);
+            }
+        }
+    }
+}
diff --git 
a/clients/src/main/java/org/apache/kafka/clients/admin/internals/PartitionLeaderStrategy.java
 
b/clients/src/main/java/org/apache/kafka/clients/admin/internals/PartitionLeaderStrategy.java
index ff7dff2db8e..e43e7914a7f 100644
--- 
a/clients/src/main/java/org/apache/kafka/clients/admin/internals/PartitionLeaderStrategy.java
+++ 
b/clients/src/main/java/org/apache/kafka/clients/admin/internals/PartitionLeaderStrategy.java
@@ -33,7 +33,6 @@ import org.slf4j.Logger;
 
 import java.util.Collections;
 import java.util.HashMap;
-import java.util.HashSet;
 import java.util.Map;
 import java.util.Set;
 import java.util.function.Function;
@@ -208,10 +207,10 @@ public class PartitionLeaderStrategy implements 
AdminApiLookupStrategy<TopicPart
      */
     public static class PartitionLeaderFuture<V> implements 
AdminApiFuture<TopicPartition, V> {
         private final Set<TopicPartition> requestKeys;
-        private final Map<TopicPartition, Integer> partitionLeaderCache;
+        private final PartitionLeaderCache partitionLeaderCache;
         private final Map<TopicPartition, KafkaFuture<V>> futures;
 
-        public PartitionLeaderFuture(Set<TopicPartition> requestKeys, 
Map<TopicPartition, Integer> partitionLeaderCache) {
+        public PartitionLeaderFuture(Set<TopicPartition> requestKeys, 
PartitionLeaderCache partitionLeaderCache) {
             this.requestKeys = requestKeys;
             this.partitionLeaderCache = partitionLeaderCache;
             this.futures = 
requestKeys.stream().collect(Collectors.toUnmodifiableMap(
@@ -225,26 +224,12 @@ public class PartitionLeaderStrategy implements 
AdminApiLookupStrategy<TopicPart
             return futures.keySet();
         }
 
-        @Override
-        public Set<TopicPartition> uncachedLookupKeys() {
-            Set<TopicPartition> keys = new HashSet<>();
-            requestKeys.forEach(tp -> {
-                if (!partitionLeaderCache.containsKey(tp)) {
-                    keys.add(tp);
-                }
-            });
-            return keys;
-        }
-
         @Override
         public Map<TopicPartition, Integer> cachedKeyBrokerIdMapping() {
+            Map<TopicPartition, Integer> cache = 
partitionLeaderCache.get(requestKeys);
+
             Map<TopicPartition, Integer> mapping = new HashMap<>();
-            requestKeys.forEach(tp -> {
-                Integer brokerId = partitionLeaderCache.get(tp);
-                if (brokerId != null) {
-                    mapping.put(tp, brokerId);
-                }
-            });
+            requestKeys.forEach(tp -> mapping.put(tp, cache.getOrDefault(tp, 
UNKNOWN_BROKER_ID)));
             return mapping;
         }
 
@@ -263,16 +248,16 @@ public class PartitionLeaderStrategy implements 
AdminApiLookupStrategy<TopicPart
 
         @Override
         public void completeLookup(Map<TopicPartition, Integer> 
brokerIdMapping) {
-            partitionLeaderCache.putAll(brokerIdMapping);
+            partitionLeaderCache.put(brokerIdMapping);
         }
 
         @Override
         public void completeExceptionally(Map<TopicPartition, Throwable> 
errors) {
+            partitionLeaderCache.remove(errors.keySet());
             errors.forEach(this::completeExceptionally);
         }
 
         private void completeExceptionally(TopicPartition key, Throwable t) {
-            partitionLeaderCache.remove(key);
             futureOrThrow(key).completeExceptionally(t);
         }
 
diff --git 
a/clients/src/test/java/org/apache/kafka/clients/admin/internals/PartitionLeaderStrategyIntegrationTest.java
 
b/clients/src/test/java/org/apache/kafka/clients/admin/internals/PartitionLeaderStrategyIntegrationTest.java
index 778502505fb..dcbc78a02f9 100644
--- 
a/clients/src/test/java/org/apache/kafka/clients/admin/internals/PartitionLeaderStrategyIntegrationTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/clients/admin/internals/PartitionLeaderStrategyIntegrationTest.java
@@ -76,7 +76,7 @@ public class PartitionLeaderStrategyIntegrationTest {
 
     @Test
     public void testCachingRepeatedRequest() {
-        Map<TopicPartition, Integer> partitionLeaderCache = new HashMap<>();
+        PartitionLeaderCache partitionLeaderCache = new PartitionLeaderCache();
 
         TopicPartition tp0 = new TopicPartition("T", 0);
         TopicPartition tp1 = new TopicPartition("T", 1);
@@ -99,8 +99,9 @@ public class PartitionLeaderStrategyIntegrationTest {
         assertFalse(result.all().get(tp0).isDone());
         assertFalse(result.all().get(tp1).isDone());
 
-        assertEquals(1, partitionLeaderCache.get(tp0));
-        assertEquals(2, partitionLeaderCache.get(tp1));
+        Map<TopicPartition, Integer> cache = 
partitionLeaderCache.get(Set.of(tp0, tp1));
+        assertEquals(1, cache.get(tp0));
+        assertEquals(2, cache.get(tp1));
 
         // Second, the fulfillment stage makes the actual requests
         requestSpecs = driver.poll();
@@ -139,7 +140,7 @@ public class PartitionLeaderStrategyIntegrationTest {
         // 2) for T-1 and T-2            (leadership data for T-1 should be 
cached from previous request)
         // 3) for T-0, T-1 and T-2       (all leadership data should be cached 
already)
         // 4) for T-0, T-1, T-2 and T-3  (just T-3 needs to be looked up)
-        Map<TopicPartition, Integer> partitionLeaderCache = new HashMap<>();
+        PartitionLeaderCache partitionLeaderCache = new PartitionLeaderCache();
 
         TopicPartition tp0 = new TopicPartition("T", 0);
         TopicPartition tp1 = new TopicPartition("T", 1);
@@ -168,8 +169,9 @@ public class PartitionLeaderStrategyIntegrationTest {
         assertFalse(result.all().get(tp0).isDone());
         assertFalse(result.all().get(tp1).isDone());
 
-        assertEquals(1, partitionLeaderCache.get(tp0));
-        assertEquals(2, partitionLeaderCache.get(tp1));
+        Map<TopicPartition, Integer> cache = 
partitionLeaderCache.get(Set.of(tp0, tp1));
+        assertEquals(1, cache.get(tp0));
+        assertEquals(2, cache.get(tp1));
 
         // Second, the fulfillment stage makes the actual requests
         requestSpecs = driver.poll();
@@ -206,9 +208,10 @@ public class PartitionLeaderStrategyIntegrationTest {
         assertTrue(result.all().get(tp1).isDone());  // Already fulfilled
         assertFalse(result.all().get(tp2).isDone());
 
-        assertEquals(1, partitionLeaderCache.get(tp0));
-        assertEquals(2, partitionLeaderCache.get(tp1));
-        assertEquals(1, partitionLeaderCache.get(tp2));
+        cache = partitionLeaderCache.get(Set.of(tp0, tp1, tp2));
+        assertEquals(1, cache.get(tp0));
+        assertEquals(2, cache.get(tp1));
+        assertEquals(1, cache.get(tp2));
 
         // Finally, the fulfillment stage makes the actual request for the 
uncached topic-partition
         requestSpecs = driver.poll();
@@ -268,10 +271,11 @@ public class PartitionLeaderStrategyIntegrationTest {
         assertTrue(result.all().get(tp2).isDone());  // Already fulfilled
         assertFalse(result.all().get(tp3).isDone());
 
-        assertEquals(1, partitionLeaderCache.get(tp0));
-        assertEquals(2, partitionLeaderCache.get(tp1));
-        assertEquals(1, partitionLeaderCache.get(tp2));
-        assertEquals(2, partitionLeaderCache.get(tp3));
+        cache = partitionLeaderCache.get(Set.of(tp0, tp1, tp2, tp3));
+        assertEquals(1, cache.get(tp0));
+        assertEquals(2, cache.get(tp1));
+        assertEquals(1, cache.get(tp2));
+        assertEquals(2, cache.get(tp3));
 
         // Finally, the fulfillment stage makes the actual request for the 
uncached topic-partition
         requestSpecs = driver.poll();
@@ -288,7 +292,7 @@ public class PartitionLeaderStrategyIntegrationTest {
 
     @Test
     public void testNotLeaderFulfillmentError() {
-        Map<TopicPartition, Integer> partitionLeaderCache = new HashMap<>();
+        PartitionLeaderCache partitionLeaderCache = new PartitionLeaderCache();
 
         TopicPartition tp0 = new TopicPartition("T", 0);
         TopicPartition tp1 = new TopicPartition("T", 1);
@@ -311,8 +315,9 @@ public class PartitionLeaderStrategyIntegrationTest {
         assertFalse(result.all().get(tp0).isDone());
         assertFalse(result.all().get(tp1).isDone());
 
-        assertEquals(1, partitionLeaderCache.get(tp0));
-        assertEquals(2, partitionLeaderCache.get(tp1));
+        Map<TopicPartition, Integer> cache = 
partitionLeaderCache.get(Set.of(tp0, tp1));
+        assertEquals(1, cache.get(tp0));
+        assertEquals(2, cache.get(tp1));
 
         // Second, the fulfillment stage makes the actual requests
         requestSpecs = driver.poll();
@@ -337,8 +342,9 @@ public class PartitionLeaderStrategyIntegrationTest {
         assertTrue(result.all().get(tp0).isDone());
         assertFalse(result.all().get(tp1).isDone());
 
-        assertEquals(1, partitionLeaderCache.get(tp0));
-        assertEquals(1, partitionLeaderCache.get(tp1));
+        cache = partitionLeaderCache.get(Set.of(tp0, tp1));
+        assertEquals(1, cache.get(tp0));
+        assertEquals(1, cache.get(tp1));
 
         // And the fulfillment stage makes the actual request
         requestSpecs = driver.poll();
@@ -354,7 +360,7 @@ public class PartitionLeaderStrategyIntegrationTest {
     @Test
     public void testFatalLookupError() {
         TopicPartition tp0 = new TopicPartition("T", 0);
-        Map<TopicPartition, Integer> partitionLeaderCache = new HashMap<>();
+        PartitionLeaderCache partitionLeaderCache = new PartitionLeaderCache();
         PartitionLeaderStrategy.PartitionLeaderFuture<Void> result =
             new 
PartitionLeaderStrategy.PartitionLeaderFuture<>(Collections.singleton(tp0), 
partitionLeaderCache);
         AdminApiDriver<TopicPartition, Void> driver = buildDriver(result);
@@ -374,7 +380,7 @@ public class PartitionLeaderStrategyIntegrationTest {
     @Test
     public void testRetryLookupAfterDisconnect() {
         TopicPartition tp0 = new TopicPartition("T", 0);
-        Map<TopicPartition, Integer> partitionLeaderCache = new HashMap<>();
+        PartitionLeaderCache partitionLeaderCache = new PartitionLeaderCache();
         PartitionLeaderStrategy.PartitionLeaderFuture<Void> result =
             new 
PartitionLeaderStrategy.PartitionLeaderFuture<>(Collections.singleton(tp0), 
partitionLeaderCache);
         AdminApiDriver<TopicPartition, Void> driver = buildDriver(result);

Reply via email to