This is an automated email from the ASF dual-hosted git repository.

rsivaram pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 98d84b17f74 KAFKA-14451: Rack-aware consumer partition assignment for 
RangeAssignor (KIP-881) (#12990)
98d84b17f74 is described below

commit 98d84b17f74b0bfe65163d0ddf88976746de5f7e
Author: Rajini Sivaram <[email protected]>
AuthorDate: Wed Mar 1 21:01:35 2023 +0000

    KAFKA-14451: Rack-aware consumer partition assignment for RangeAssignor 
(KIP-881) (#12990)
    
    Best-effort rack alignment for range assignor when both consumer racks and 
partition racks are available with the protocol changes introduced in KIP-881. 
Rack-aware assignment is enabled by configuring client.rack for consumers. 
Balanced assignment per topic is prioritized over rack-alignment. For topics 
with equal partitions and the same set of subscribers, co-partitioning is 
prioritized over rack-alignment.
    
    Reviewers: David Jacot <[email protected]>
---
 .../kafka/clients/consumer/RangeAssignor.java      | 229 +++++++++--
 .../internals/AbstractPartitionAssignor.java       |  66 +++-
 .../kafka/clients/consumer/internals/Utils.java    |   4 +-
 .../kafka/clients/consumer/RangeAssignorTest.java  | 436 ++++++++++++++++-----
 .../clients/consumer/RoundRobinAssignorTest.java   |  11 +-
 .../internals/AbstractPartitionAssignorTest.java   | 169 ++++++++
 .../kafka/api/PlaintextConsumerTest.scala          |  36 +-
 .../server/FetchFromFollowerIntegrationTest.scala  |  66 +++-
 8 files changed, 852 insertions(+), 165 deletions(-)

diff --git 
a/clients/src/main/java/org/apache/kafka/clients/consumer/RangeAssignor.java 
b/clients/src/main/java/org/apache/kafka/clients/consumer/RangeAssignor.java
index aec0d3997c4..0b3071a4915 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/RangeAssignor.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/RangeAssignor.java
@@ -17,13 +17,27 @@
 package org.apache.kafka.clients.consumer;
 
 import org.apache.kafka.clients.consumer.internals.AbstractPartitionAssignor;
+import 
org.apache.kafka.clients.consumer.internals.Utils.TopicPartitionComparator;
+import org.apache.kafka.common.Node;
+import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.TopicPartition;
 
 import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
+import java.util.function.BiFunction;
+import java.util.function.Function;
+import java.util.stream.Collectors;
 
 /**
  * <p>The range assignor works on a per-topic basis. For each topic, we lay 
out the available partitions in numeric order
@@ -63,9 +77,26 @@ import java.util.Map;
  * <li><code>I0: [t0p0, t0p1, t1p0, t1p1]</code>
  * <li><code>I1: [t0p2, t1p2]</code>
  * </ul>
+ * <p>
+ * Rack-aware assignment is used if both consumer and partition replica racks 
are available and
+ * some partitions have replicas only on a subset of racks. We attempt to 
match consumer racks with
+ * partition replica racks on a best-effort basis, prioritizing balanced 
assignment over rack alignment.
+ * Topics with equal partition count and same set of subscribers guarantee 
co-partitioning by prioritizing
+ * co-partitioning over rack alignment. In this case, aligning partition 
replicas of these topics on the
+ * same racks will improve locality for consumers. For example, if partitions 
0 of all topics have a replica
+ * on rack 'a', partition 1 on rack 'b' etc., partition 0 of all topics can be 
assigned to a consumer
+ * on rack 'a', partition 1 to a consumer on rack 'b' and so on.
+ * <p>
+ * Note that rack-aware assignment currently takes all replicas into account, 
including any offline replicas
+ * and replicas that are not in the ISR. This is based on the assumption that 
these replicas are likely
+ * to join the ISR relatively soon. Since consumers don't rebalance on ISR 
change, this avoids unnecessary
+ * cross-rack traffic for long durations after replicas rejoin the ISR. In the 
future, we may consider
+ * rebalancing when replicas are added or removed to improve consumer rack 
alignment.
+ * </p>
  */
 public class RangeAssignor extends AbstractPartitionAssignor {
     public static final String RANGE_ASSIGNOR_NAME = "range";
+    private static final TopicPartitionComparator PARTITION_COMPARATOR = new 
TopicPartitionComparator();
 
     @Override
     public String name() {
@@ -74,45 +105,193 @@ public class RangeAssignor extends 
AbstractPartitionAssignor {
 
     private Map<String, List<MemberInfo>> consumersPerTopic(Map<String, 
Subscription> consumerMetadata) {
         Map<String, List<MemberInfo>> topicToConsumers = new HashMap<>();
-        for (Map.Entry<String, Subscription> subscriptionEntry : 
consumerMetadata.entrySet()) {
-            String consumerId = subscriptionEntry.getKey();
-            MemberInfo memberInfo = new MemberInfo(consumerId, 
subscriptionEntry.getValue().groupInstanceId());
-            for (String topic : subscriptionEntry.getValue().topics()) {
-                put(topicToConsumers, topic, memberInfo);
-            }
-        }
+        consumerMetadata.forEach((consumerId, subscription) -> {
+            MemberInfo memberInfo = new MemberInfo(consumerId, 
subscription.groupInstanceId(), subscription.rackId());
+            subscription.topics().forEach(topic -> put(topicToConsumers, 
topic, memberInfo));
+        });
         return topicToConsumers;
     }
 
+    /**
+     * Performs range assignment of the specified partitions for the consumers 
with the provided subscriptions.
+     * If rack-awareness is enabled for one or more consumers, we perform 
rack-aware assignment first to assign
+     * the subset of partitions that can be aligned on racks, while retaining 
the same co-partitioning and
+     * per-topic balancing guarantees as non-rack-aware range assignment. The 
remaining partitions are assigned
+     * using standard non-rack-aware range assignment logic, which may result 
in mis-aligned racks.
+     */
     @Override
-    public Map<String, List<TopicPartition>> assign(Map<String, Integer> 
partitionsPerTopic,
-                                                    Map<String, Subscription> 
subscriptions) {
+    public Map<String, List<TopicPartition>> assignPartitions(Map<String, 
List<PartitionInfo>> partitionsPerTopic,
+                                                              Map<String, 
Subscription> subscriptions) {
         Map<String, List<MemberInfo>> consumersPerTopic = 
consumersPerTopic(subscriptions);
+        Map<String, String> consumerRacks = consumerRacks(subscriptions);
+        List<TopicAssignmentState> topicAssignmentStates = 
partitionsPerTopic.entrySet().stream()
+                .filter(e -> !e.getValue().isEmpty())
+                .map(e -> new TopicAssignmentState(e.getKey(), e.getValue(), 
consumersPerTopic.get(e.getKey()), consumerRacks))
+                .collect(Collectors.toList());
 
         Map<String, List<TopicPartition>> assignment = new HashMap<>();
-        for (String memberId : subscriptions.keySet())
-            assignment.put(memberId, new ArrayList<>());
+        subscriptions.keySet().forEach(memberId -> assignment.put(memberId, 
new ArrayList<>()));
+
+        boolean useRackAware = topicAssignmentStates.stream().anyMatch(t -> 
t.needsRackAwareAssignment);
+        if (useRackAware)
+            assignWithRackMatching(topicAssignmentStates, assignment);
 
-        for (Map.Entry<String, List<MemberInfo>> topicEntry : 
consumersPerTopic.entrySet()) {
-            String topic = topicEntry.getKey();
-            List<MemberInfo> consumersForTopic = topicEntry.getValue();
+        topicAssignmentStates.forEach(t -> assignRanges(t, (c, tp) -> true, 
assignment));
 
-            Integer numPartitionsForTopic = partitionsPerTopic.get(topic);
-            if (numPartitionsForTopic == null)
+        if (useRackAware)
+            assignment.values().forEach(list -> 
list.sort(PARTITION_COMPARATOR));
+        return assignment;
+    }
+
+    // This method is not used, but retained for compatibility with any custom 
assignors that extend this class.
+    @Override
+    public Map<String, List<TopicPartition>> assign(Map<String, Integer> 
partitionsPerTopic,
+                                                    Map<String, Subscription> 
subscriptions) {
+        return 
assignPartitions(partitionInfosWithoutRacks(partitionsPerTopic), subscriptions);
+    }
+
+    private void assignRanges(TopicAssignmentState assignmentState,
+                              BiFunction<String, TopicPartition, Boolean> 
mayAssign,
+                              Map<String, List<TopicPartition>> assignment) {
+        for (String consumer : assignmentState.consumers.keySet()) {
+            if (assignmentState.unassignedPartitions.isEmpty())
+                break;
+            List<TopicPartition> assignablePartitions = 
assignmentState.unassignedPartitions.stream()
+                    .filter(tp -> mayAssign.apply(consumer, tp))
+                    .limit(assignmentState.maxAssignable(consumer))
+                    .collect(Collectors.toList());
+            if (assignablePartitions.isEmpty())
                 continue;
 
-            Collections.sort(consumersForTopic);
+            assign(consumer, assignablePartitions, assignmentState, 
assignment);
+        }
+    }
+
+    private void assignWithRackMatching(Collection<TopicAssignmentState> 
assignmentStates,
+                                        Map<String, List<TopicPartition>> 
assignment) {
 
-            int numPartitionsPerConsumer = numPartitionsForTopic / 
consumersForTopic.size();
-            int consumersWithExtraPartition = numPartitionsForTopic % 
consumersForTopic.size();
+        assignmentStates.stream().collect(Collectors.groupingBy(t -> 
t.consumers)).forEach((consumers, states) -> {
+            states.stream().collect(Collectors.groupingBy(t -> 
t.partitionRacks.size())).forEach((numPartitions, coPartitionedStates) -> {
+                if (coPartitionedStates.size() > 1)
+                    assignCoPartitionedWithRackMatching(consumers, 
numPartitions, coPartitionedStates, assignment);
+                else {
+                    TopicAssignmentState state = coPartitionedStates.get(0);
+                    if (state.needsRackAwareAssignment)
+                        assignRanges(state, state::racksMatch, assignment);
+                }
+            });
+        });
+    }
+
+    private void assignCoPartitionedWithRackMatching(LinkedHashMap<String, 
Optional<String>> consumers,
+                                                     int numPartitions,
+                                                     
Collection<TopicAssignmentState> assignmentStates,
+                                                     Map<String, 
List<TopicPartition>> assignment) {
+
+        Set<String> remainingConsumers = new 
LinkedHashSet<>(consumers.keySet());
+        for (int i = 0; i < numPartitions; i++) {
+            int p = i;
 
-            List<TopicPartition> partitions = 
AbstractPartitionAssignor.partitions(topic, numPartitionsForTopic);
-            for (int i = 0, n = consumersForTopic.size(); i < n; i++) {
-                int start = numPartitionsPerConsumer * i + Math.min(i, 
consumersWithExtraPartition);
-                int length = numPartitionsPerConsumer + (i + 1 > 
consumersWithExtraPartition ? 0 : 1);
-                
assignment.get(consumersForTopic.get(i).memberId).addAll(partitions.subList(start,
 start + length));
+            Optional<String> matchingConsumer = remainingConsumers.stream()
+                    .filter(c -> assignmentStates.stream().allMatch(t -> 
t.racksMatch(c, new TopicPartition(t.topic, p)) && t.maxAssignable(c) > 0))
+                    .findFirst();
+            if (matchingConsumer.isPresent()) {
+                String consumer = matchingConsumer.get();
+                assignmentStates.forEach(t -> assign(consumer, 
Collections.singletonList(new TopicPartition(t.topic, p)), t, assignment));
+
+                if (assignmentStates.stream().noneMatch(t -> 
t.maxAssignable(consumer) > 0)) {
+                    remainingConsumers.remove(consumer);
+                    if (remainingConsumers.isEmpty())
+                        break;
+                }
             }
         }
-        return assignment;
+    }
+
+    private void assign(String consumer, List<TopicPartition> partitions, 
TopicAssignmentState assignmentState, Map<String, List<TopicPartition>> 
assignment) {
+        assignment.get(consumer).addAll(partitions);
+        assignmentState.onAssigned(consumer, partitions);
+    }
+
+    private Map<String, String> consumerRacks(Map<String, Subscription> 
subscriptions) {
+        Map<String, String> consumerRacks = new 
HashMap<>(subscriptions.size());
+        subscriptions.forEach((memberId, subscription) ->
+                subscription.rackId().filter(r -> 
!r.isEmpty()).ifPresent(rackId -> consumerRacks.put(memberId, rackId)));
+        return consumerRacks;
+    }
+
+    private class TopicAssignmentState {
+        private final String topic;
+        private final LinkedHashMap<String, Optional<String>> consumers;
+        private final boolean needsRackAwareAssignment;
+        private final Map<TopicPartition, Set<String>> partitionRacks;
+
+        private final Set<TopicPartition> unassignedPartitions;
+        private final Map<String, Integer> numAssignedByConsumer;
+        private final int numPartitionsPerConsumer;
+        private int remainingConsumersWithExtraPartition;
+
+        public TopicAssignmentState(String topic, List<PartitionInfo> 
partitionInfos, List<MemberInfo> membersOrNull, Map<String, String> 
consumerRacks) {
+            this.topic = topic;
+            List<MemberInfo> members = membersOrNull == null ? 
Collections.emptyList() : membersOrNull;
+            Collections.sort(members);
+            consumers = members.stream().map(c -> c.memberId)
+                    .collect(Collectors.toMap(Function.identity(), c -> 
Optional.ofNullable(consumerRacks.get(c)), (a, b) -> a, LinkedHashMap::new));
+
+            this.unassignedPartitions = partitionInfos.stream().map(p -> new 
TopicPartition(p.topic(), p.partition()))
+                    .collect(Collectors.toCollection(LinkedHashSet::new));
+            this.numAssignedByConsumer = 
consumers.keySet().stream().collect(Collectors.toMap(Function.identity(), c -> 
0));
+            numPartitionsPerConsumer = consumers.isEmpty() ? 0 : 
partitionInfos.size() / consumers.size();
+            remainingConsumersWithExtraPartition = consumers.isEmpty() ? 0 : 
partitionInfos.size() % consumers.size();
+
+            Set<String> allConsumerRacks = new HashSet<>();
+            Set<String> allPartitionRacks = new HashSet<>();
+            members.stream().map(m -> 
m.memberId).filter(consumerRacks::containsKey)
+                    .forEach(memberId -> 
allConsumerRacks.add(consumerRacks.get(memberId)));
+            if (!allConsumerRacks.isEmpty()) {
+                partitionRacks = new HashMap<>(partitionInfos.size());
+                partitionInfos.forEach(p -> {
+                    TopicPartition tp = new TopicPartition(p.topic(), 
p.partition());
+                    Set<String> racks = Arrays.stream(p.replicas())
+                            .map(Node::rack)
+                            .filter(Objects::nonNull)
+                            .collect(Collectors.toSet());
+                    partitionRacks.put(tp, racks);
+                    allPartitionRacks.addAll(racks);
+                });
+            } else {
+                partitionRacks = Collections.emptyMap();
+            }
+
+            needsRackAwareAssignment = 
useRackAwareAssignment(allConsumerRacks, allPartitionRacks, partitionRacks);
+        }
+
+        boolean racksMatch(String consumer, TopicPartition tp) {
+            Optional<String> consumerRack = consumers.get(consumer);
+            Set<String> replicaRacks = partitionRacks.get(tp);
+            return !consumerRack.isPresent() || (replicaRacks != null && 
replicaRacks.contains(consumerRack.get()));
+        }
+
+        int maxAssignable(String consumer) {
+            int maxForConsumer = numPartitionsPerConsumer + 
(remainingConsumersWithExtraPartition > 0 ? 1 : 0) - 
numAssignedByConsumer.get(consumer);
+            return Math.max(0, maxForConsumer);
+        }
+
+        void onAssigned(String consumer, List<TopicPartition> 
newlyAssignedPartitions) {
+            int numAssigned = numAssignedByConsumer.compute(consumer, (c, n) 
-> n + newlyAssignedPartitions.size());
+            if (numAssigned > numPartitionsPerConsumer)
+                remainingConsumersWithExtraPartition--;
+            unassignedPartitions.removeAll(newlyAssignedPartitions);
+        }
+
+        @Override
+        public String toString() {
+            return "TopicAssignmentState(" +
+                    "topic=" + topic +
+                    ", consumers=" + consumers +
+                    ", partitionRacks=" + partitionRacks +
+                    ", unassignedPartitions=" + unassignedPartitions +
+                    ")";
+        }
     }
 }
diff --git 
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractPartitionAssignor.java
 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractPartitionAssignor.java
index ed0282b4062..f6beb477f15 100644
--- 
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractPartitionAssignor.java
+++ 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractPartitionAssignor.java
@@ -18,17 +18,23 @@ package org.apache.kafka.clients.consumer.internals;
 
 import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor;
 import org.apache.kafka.common.Cluster;
+import org.apache.kafka.common.Node;
+import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.TopicPartition;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Map.Entry;
 import java.util.Optional;
 import java.util.Set;
+import java.util.stream.Collectors;
 
 /**
  * Abstract assignor implementation which does some common grunt work (in 
particular collecting
@@ -36,6 +42,10 @@ import java.util.Set;
  */
 public abstract class AbstractPartitionAssignor implements 
ConsumerPartitionAssignor {
     private static final Logger log = 
LoggerFactory.getLogger(AbstractPartitionAssignor.class);
+    private static final Node[] NO_NODES = new Node[] {Node.noNode()};
+
+    // Used only in unit tests to verify rack-aware assignment when all racks 
have all partitions.
+    boolean preferRackAwareLogic;
 
     /**
      * Perform the group assignment given the partition counts and member 
subscriptions
@@ -47,6 +57,18 @@ public abstract class AbstractPartitionAssignor implements 
ConsumerPartitionAssi
     public abstract Map<String, List<TopicPartition>> assign(Map<String, 
Integer> partitionsPerTopic,
                                                              Map<String, 
Subscription> subscriptions);
 
+    /**
+     * Default implementation of assignPartitions() that does not include 
racks. This is only
+     * included to avoid breaking any custom implementation that extends 
AbstractPartitionAssignor.
+     * Note that this class is internal, but to be safe, we are maintaining 
compatibility.
+     */
+    public Map<String, List<TopicPartition>> assignPartitions(Map<String, 
List<PartitionInfo>> partitionsPerTopic,
+            Map<String, Subscription> subscriptions) {
+        Map<String, Integer> partitionCountPerTopic = 
partitionsPerTopic.entrySet().stream()
+                .collect(Collectors.toMap(Entry::getKey, e -> 
e.getValue().size()));
+        return assign(partitionCountPerTopic, subscriptions);
+    }
+
     @Override
     public GroupAssignment assign(Cluster metadata, GroupSubscription 
groupSubscription) {
         Map<String, Subscription> subscriptions = 
groupSubscription.groupSubscription();
@@ -54,16 +76,19 @@ public abstract class AbstractPartitionAssignor implements 
ConsumerPartitionAssi
         for (Map.Entry<String, Subscription> subscriptionEntry : 
subscriptions.entrySet())
             allSubscribedTopics.addAll(subscriptionEntry.getValue().topics());
 
-        Map<String, Integer> partitionsPerTopic = new HashMap<>();
+        Map<String, List<PartitionInfo>> partitionsPerTopic = new HashMap<>();
         for (String topic : allSubscribedTopics) {
-            Integer numPartitions = metadata.partitionCountForTopic(topic);
-            if (numPartitions != null && numPartitions > 0)
-                partitionsPerTopic.put(topic, numPartitions);
-            else
+            List<PartitionInfo> partitions = 
metadata.partitionsForTopic(topic);
+            if (partitions != null && !partitions.isEmpty()) {
+                partitions = new ArrayList<>(partitions);
+                
partitions.sort(Comparator.comparingInt(PartitionInfo::partition));
+                partitionsPerTopic.put(topic, partitions);
+            } else {
                 log.debug("Skipping assignment for topic {} since no metadata 
is available", topic);
+            }
         }
 
-        Map<String, List<TopicPartition>> rawAssignments = 
assign(partitionsPerTopic, subscriptions);
+        Map<String, List<TopicPartition>> rawAssignments = 
assignPartitions(partitionsPerTopic, subscriptions);
 
         // this class maintains no user data, so just wrap the results
         Map<String, Assignment> assignments = new HashMap<>();
@@ -84,13 +109,40 @@ public abstract class AbstractPartitionAssignor implements 
ConsumerPartitionAssi
         return partitions;
     }
 
+    protected static Map<String, List<PartitionInfo>> 
partitionInfosWithoutRacks(Map<String, Integer> partitionsPerTopic) {
+        return 
partitionsPerTopic.entrySet().stream().collect(Collectors.toMap(Entry::getKey, 
e -> {
+            String topic = e.getKey();
+            int numPartitions = e.getValue();
+            List<PartitionInfo> partitionInfos = new 
ArrayList<>(numPartitions);
+            for (int i = 0; i < numPartitions; i++)
+                partitionInfos.add(new PartitionInfo(topic, i, Node.noNode(), 
NO_NODES, NO_NODES));
+            return partitionInfos;
+        }));
+    }
+
+    protected boolean useRackAwareAssignment(Set<String> consumerRacks, 
Set<String> partitionRacks, Map<TopicPartition, Set<String>> racksPerPartition) 
{
+        if (consumerRacks.isEmpty() || Collections.disjoint(consumerRacks, 
partitionRacks))
+            return false;
+        else if (preferRackAwareLogic)
+            return true;
+        else {
+            return 
!racksPerPartition.values().stream().allMatch(partitionRacks::equals);
+        }
+    }
+
     public static class MemberInfo implements Comparable<MemberInfo> {
         public final String memberId;
         public final Optional<String> groupInstanceId;
+        public final Optional<String> rackId;
 
-        public MemberInfo(String memberId, Optional<String> groupInstanceId) {
+        public MemberInfo(String memberId, Optional<String> groupInstanceId, 
Optional<String> rackId) {
             this.memberId = memberId;
             this.groupInstanceId = groupInstanceId;
+            this.rackId = rackId;
+        }
+
+        public MemberInfo(String memberId, Optional<String> groupInstanceId) {
+            this(memberId, groupInstanceId, Optional.empty());
         }
 
         @Override
diff --git 
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Utils.java 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Utils.java
index c86edf3c98c..acad4730393 100644
--- 
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Utils.java
+++ 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Utils.java
@@ -22,7 +22,7 @@ import java.util.List;
 import java.util.Map;
 import org.apache.kafka.common.TopicPartition;
 
-final class Utils {
+public final class Utils {
 
     final static class PartitionComparator implements 
Comparator<TopicPartition>, Serializable {
         private static final long serialVersionUID = 1L;
@@ -44,7 +44,7 @@ final class Utils {
         }
     }
 
-    final static class TopicPartitionComparator implements 
Comparator<TopicPartition>, Serializable {
+    public final static class TopicPartitionComparator implements 
Comparator<TopicPartition>, Serializable {
         private static final long serialVersionUID = 1L;
 
         @Override
diff --git 
a/clients/src/test/java/org/apache/kafka/clients/consumer/RangeAssignorTest.java
 
b/clients/src/test/java/org/apache/kafka/clients/consumer/RangeAssignorTest.java
index e067e6fdaad..fa60f672388 100644
--- 
a/clients/src/test/java/org/apache/kafka/clients/consumer/RangeAssignorTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/clients/consumer/RangeAssignorTest.java
@@ -19,9 +19,15 @@ package org.apache.kafka.clients.consumer;
 import 
org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.Subscription;
 import org.apache.kafka.clients.consumer.internals.AbstractPartitionAssignor;
 import 
org.apache.kafka.clients.consumer.internals.AbstractPartitionAssignor.MemberInfo;
+import 
org.apache.kafka.clients.consumer.internals.AbstractPartitionAssignorTest;
+import 
org.apache.kafka.clients.consumer.internals.AbstractPartitionAssignorTest.RackConfig;
+import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.TopicPartition;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.EnumSource;
+import org.junit.jupiter.params.provider.ValueSource;
 
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -32,16 +38,25 @@ import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 
+import static java.util.Arrays.asList;
+import static 
org.apache.kafka.clients.consumer.internals.AbstractPartitionAssignorTest.ALL_RACKS;
+import static 
org.apache.kafka.clients.consumer.internals.AbstractPartitionAssignorTest.TEST_NAME_WITH_CONSUMER_RACK;
+import static 
org.apache.kafka.clients.consumer.internals.AbstractPartitionAssignorTest.TEST_NAME_WITH_RACK_CONFIG;
+import static 
org.apache.kafka.clients.consumer.internals.AbstractPartitionAssignorTest.racks;
+import static 
org.apache.kafka.clients.consumer.internals.AbstractPartitionAssignorTest.nullRacks;
+import static 
org.apache.kafka.clients.consumer.internals.AbstractPartitionAssignorTest.verifyRackAssignment;
+import static org.apache.kafka.common.utils.Utils.mkEntry;
+import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
 public class RangeAssignorTest {
 
-    private RangeAssignor assignor = new RangeAssignor();
+    private final RangeAssignor assignor = new RangeAssignor();
 
     // For plural tests
-    private String topic1 = "topic1";
-    private String topic2 = "topic2";
+    private final String topic1 = "topic1";
+    private final String topic2 = "topic2";
     private final String consumer1 = "consumer1";
     private final String instance1 = "instance1";
     private final String consumer2 = "consumer2";
@@ -49,7 +64,11 @@ public class RangeAssignorTest {
     private final String consumer3 = "consumer3";
     private final String instance3 = "instance3";
 
+    private int numBrokerRacks;
+    private boolean hasConsumerRack;
+
     private List<MemberInfo> staticMemberInfos;
+    private int replicationFactor = 3;
 
     @BeforeEach
     public void setUp() {
@@ -59,188 +78,205 @@ public class RangeAssignorTest {
         staticMemberInfos.add(new MemberInfo(consumer3, 
Optional.of(instance3)));
     }
 
-    @Test
-    public void testOneConsumerNoTopic() {
-        Map<String, Integer> partitionsPerTopic = new HashMap<>();
+    @ParameterizedTest(name = TEST_NAME_WITH_CONSUMER_RACK)
+    @ValueSource(booleans = {true, false})
+    public void testOneConsumerNoTopic(boolean hasConsumerRack) {
+        initializeRacks(hasConsumerRack ? RackConfig.BROKER_AND_CONSUMER_RACK 
: RackConfig.NO_CONSUMER_RACK);
+        Map<String, List<PartitionInfo>> partitionsPerTopic = new HashMap<>();
 
-        Map<String, List<TopicPartition>> assignment = 
assignor.assign(partitionsPerTopic,
-                Collections.singletonMap(consumer1, new 
Subscription(Collections.emptyList())));
+        Map<String, List<TopicPartition>> assignment = 
assignor.assignPartitions(partitionsPerTopic,
+                Collections.singletonMap(consumer1, 
subscription(Collections.emptyList(), 0)));
 
         assertEquals(Collections.singleton(consumer1), assignment.keySet());
         assertTrue(assignment.get(consumer1).isEmpty());
     }
 
-    @Test
-    public void testOneConsumerNonexistentTopic() {
-        Map<String, Integer> partitionsPerTopic = new HashMap<>();
-        Map<String, List<TopicPartition>> assignment = 
assignor.assign(partitionsPerTopic,
-                Collections.singletonMap(consumer1, new 
Subscription(topics(topic1))));
+    @ParameterizedTest(name = TEST_NAME_WITH_CONSUMER_RACK)
+    @ValueSource(booleans = {true, false})
+    public void testOneConsumerNonexistentTopic(boolean hasConsumerRack) {
+        initializeRacks(hasConsumerRack ? RackConfig.BROKER_AND_CONSUMER_RACK 
: RackConfig.NO_CONSUMER_RACK);
+        Map<String, List<PartitionInfo>> partitionsPerTopic = new HashMap<>();
+        Map<String, List<TopicPartition>> assignment = 
assignor.assignPartitions(partitionsPerTopic,
+                Collections.singletonMap(consumer1, 
subscription(topics(topic1), 0)));
         assertEquals(Collections.singleton(consumer1), assignment.keySet());
         assertTrue(assignment.get(consumer1).isEmpty());
     }
 
-    @Test
-    public void testOneConsumerOneTopic() {
-        Map<String, Integer> partitionsPerTopic = new HashMap<>();
-        partitionsPerTopic.put(topic1, 3);
+    @ParameterizedTest(name = "rackConfig = {0}")
+    @EnumSource(RackConfig.class)
+    public void testOneConsumerOneTopic(RackConfig rackConfig) {
+        initializeRacks(rackConfig);
+        Map<String, List<PartitionInfo>> partitionsPerTopic = new HashMap<>();
+        partitionsPerTopic.put(topic1, partitionInfos(topic1, 3));
 
-        Map<String, List<TopicPartition>> assignment = 
assignor.assign(partitionsPerTopic,
-                Collections.singletonMap(consumer1, new 
Subscription(topics(topic1))));
+        Map<String, List<TopicPartition>> assignment = 
assignor.assignPartitions(partitionsPerTopic,
+                Collections.singletonMap(consumer1, 
subscription(topics(topic1), 0)));
 
         assertEquals(Collections.singleton(consumer1), assignment.keySet());
         assertAssignment(partitions(tp(topic1, 0), tp(topic1, 1), tp(topic1, 
2)), assignment.get(consumer1));
     }
 
-    @Test
-    public void testOnlyAssignsPartitionsFromSubscribedTopics() {
+    @ParameterizedTest(name = TEST_NAME_WITH_RACK_CONFIG)
+    @EnumSource(RackConfig.class)
+    public void testOnlyAssignsPartitionsFromSubscribedTopics(RackConfig 
rackConfig) {
+        initializeRacks(rackConfig);
         String otherTopic = "other";
 
-        Map<String, Integer> partitionsPerTopic = new HashMap<>();
-        partitionsPerTopic.put(topic1, 3);
-        partitionsPerTopic.put(otherTopic, 3);
+        Map<String, List<PartitionInfo>> partitionsPerTopic = new HashMap<>();
+        partitionsPerTopic.put(topic1, partitionInfos(topic1, 3));
+        partitionsPerTopic.put(otherTopic, partitionInfos(otherTopic, 3));
 
-        Map<String, List<TopicPartition>> assignment = 
assignor.assign(partitionsPerTopic,
-                Collections.singletonMap(consumer1, new 
Subscription(topics(topic1))));
+        Map<String, List<TopicPartition>> assignment = 
assignor.assignPartitions(partitionsPerTopic,
+                Collections.singletonMap(consumer1, 
subscription(topics(topic1), 0)));
         assertEquals(Collections.singleton(consumer1), assignment.keySet());
         assertAssignment(partitions(tp(topic1, 0), tp(topic1, 1), tp(topic1, 
2)), assignment.get(consumer1));
     }
 
-    @Test
-    public void testOneConsumerMultipleTopics() {
-        Map<String, Integer> partitionsPerTopic = 
setupPartitionsPerTopicWithTwoTopics(1, 2);
+    @ParameterizedTest(name = TEST_NAME_WITH_RACK_CONFIG)
+    @EnumSource(RackConfig.class)
+    public void testOneConsumerMultipleTopics(RackConfig rackConfig) {
+        initializeRacks(rackConfig);
+        Map<String, List<PartitionInfo>> partitionsPerTopic = 
setupPartitionsPerTopicWithTwoTopics(1, 2);
 
-        Map<String, List<TopicPartition>> assignment = 
assignor.assign(partitionsPerTopic,
-                Collections.singletonMap(consumer1, new 
Subscription(topics(topic1, topic2))));
+        Map<String, List<TopicPartition>> assignment = 
assignor.assignPartitions(partitionsPerTopic,
+                Collections.singletonMap(consumer1, 
subscription(topics(topic1, topic2), 0)));
 
         assertEquals(Collections.singleton(consumer1), assignment.keySet());
         assertAssignment(partitions(tp(topic1, 0), tp(topic2, 0), tp(topic2, 
1)), assignment.get(consumer1));
     }
 
-    @Test
-    public void testTwoConsumersOneTopicOnePartition() {
-        Map<String, Integer> partitionsPerTopic = new HashMap<>();
-        partitionsPerTopic.put(topic1, 1);
+    @ParameterizedTest(name = TEST_NAME_WITH_RACK_CONFIG)
+    @EnumSource(RackConfig.class)
+    public void testTwoConsumersOneTopicOnePartition(RackConfig rackConfig) {
+        initializeRacks(rackConfig);
+        Map<String, List<PartitionInfo>> partitionsPerTopic = new HashMap<>();
+        partitionsPerTopic.put(topic1, partitionInfos(topic1, 1));
 
         Map<String, Subscription> consumers = new HashMap<>();
-        consumers.put(consumer1, new Subscription(topics(topic1)));
-        consumers.put(consumer2, new Subscription(topics(topic1)));
+        consumers.put(consumer1, subscription(topics(topic1), 0));
+        consumers.put(consumer2, subscription(topics(topic1), 1));
 
-        Map<String, List<TopicPartition>> assignment = 
assignor.assign(partitionsPerTopic, consumers);
+        Map<String, List<TopicPartition>> assignment = 
assignor.assignPartitions(partitionsPerTopic, consumers);
         assertAssignment(partitions(tp(topic1, 0)), assignment.get(consumer1));
         assertAssignment(Collections.emptyList(), assignment.get(consumer2));
     }
 
 
-    @Test
-    public void testTwoConsumersOneTopicTwoPartitions() {
-        Map<String, Integer> partitionsPerTopic = new HashMap<>();
-        partitionsPerTopic.put(topic1, 2);
+    @ParameterizedTest(name = TEST_NAME_WITH_RACK_CONFIG)
+    @EnumSource(RackConfig.class)
+    public void testTwoConsumersOneTopicTwoPartitions(RackConfig rackConfig) {
+        initializeRacks(rackConfig);
+        Map<String, List<PartitionInfo>> partitionsPerTopic = new HashMap<>();
+        partitionsPerTopic.put(topic1, partitionInfos(topic1, 2));
 
         Map<String, Subscription> consumers = new HashMap<>();
-        consumers.put(consumer1, new Subscription(topics(topic1)));
-        consumers.put(consumer2, new Subscription(topics(topic1)));
+        consumers.put(consumer1, subscription(topics(topic1), 0));
+        consumers.put(consumer2, subscription(topics(topic1), 1));
 
-        Map<String, List<TopicPartition>> assignment = 
assignor.assign(partitionsPerTopic, consumers);
+        Map<String, List<TopicPartition>> assignment = 
assignor.assignPartitions(partitionsPerTopic, consumers);
         assertAssignment(partitions(tp(topic1, 0)), assignment.get(consumer1));
         assertAssignment(partitions(tp(topic1, 1)), assignment.get(consumer2));
     }
 
-    @Test
-    public void testMultipleConsumersMixedTopics() {
-        Map<String, Integer> partitionsPerTopic = 
setupPartitionsPerTopicWithTwoTopics(3, 2);
+    @ParameterizedTest(name = TEST_NAME_WITH_RACK_CONFIG)
+    @EnumSource(RackConfig.class)
+    public void testMultipleConsumersMixedTopics(RackConfig rackConfig) {
+        initializeRacks(rackConfig);
+        Map<String, List<PartitionInfo>> partitionsPerTopic = 
setupPartitionsPerTopicWithTwoTopics(3, 2);
 
         Map<String, Subscription> consumers = new HashMap<>();
-        consumers.put(consumer1, new Subscription(topics(topic1)));
-        consumers.put(consumer2, new Subscription(topics(topic1, topic2)));
-        consumers.put(consumer3, new Subscription(topics(topic1)));
+        consumers.put(consumer1, subscription(topics(topic1), 0));
+        consumers.put(consumer2, subscription(topics(topic1, topic2), 1));
+        consumers.put(consumer3, subscription(topics(topic1), 2));
 
-        Map<String, List<TopicPartition>> assignment = 
assignor.assign(partitionsPerTopic, consumers);
+        Map<String, List<TopicPartition>> assignment = 
assignor.assignPartitions(partitionsPerTopic, consumers);
         assertAssignment(partitions(tp(topic1, 0)), assignment.get(consumer1));
         assertAssignment(partitions(tp(topic1, 1), tp(topic2, 0), tp(topic2, 
1)), assignment.get(consumer2));
         assertAssignment(partitions(tp(topic1, 2)), assignment.get(consumer3));
     }
 
-    @Test
-    public void testTwoConsumersTwoTopicsSixPartitions() {
+    @ParameterizedTest(name = TEST_NAME_WITH_RACK_CONFIG)
+    @EnumSource(RackConfig.class)
+    public void testTwoConsumersTwoTopicsSixPartitions(RackConfig rackConfig) {
+        initializeRacks(rackConfig);
         String topic1 = "topic1";
         String topic2 = "topic2";
         String consumer1 = "consumer1";
         String consumer2 = "consumer2";
 
-        Map<String, Integer> partitionsPerTopic = 
setupPartitionsPerTopicWithTwoTopics(3, 3);
+        Map<String, List<PartitionInfo>> partitionsPerTopic = 
setupPartitionsPerTopicWithTwoTopics(3, 3);
 
         Map<String, Subscription> consumers = new HashMap<>();
-        consumers.put(consumer1, new Subscription(topics(topic1, topic2)));
-        consumers.put(consumer2, new Subscription(topics(topic1, topic2)));
+        consumers.put(consumer1, subscription(topics(topic1, topic2), 0));
+        consumers.put(consumer2, subscription(topics(topic1, topic2), 1));
 
-        Map<String, List<TopicPartition>> assignment = 
assignor.assign(partitionsPerTopic, consumers);
+        Map<String, List<TopicPartition>> assignment = 
assignor.assignPartitions(partitionsPerTopic, consumers);
         assertAssignment(partitions(tp(topic1, 0), tp(topic1, 1), tp(topic2, 
0), tp(topic2, 1)), assignment.get(consumer1));
         assertAssignment(partitions(tp(topic1, 2), tp(topic2, 2)), 
assignment.get(consumer2));
     }
 
-    @Test
-    public void testTwoStaticConsumersTwoTopicsSixPartitions() {
+    @ParameterizedTest(name = TEST_NAME_WITH_RACK_CONFIG)
+    @EnumSource(RackConfig.class)
+    public void testTwoStaticConsumersTwoTopicsSixPartitions(RackConfig 
rackConfig) {
+        initializeRacks(rackConfig);
         // although consumer high has a higher rank than consumer low, the 
comparison happens on
         // instance id level.
         String consumerIdLow = "consumer-b";
         String consumerIdHigh = "consumer-a";
 
-        Map<String, Integer> partitionsPerTopic = 
setupPartitionsPerTopicWithTwoTopics(3, 3);
+        Map<String, List<PartitionInfo>> partitionsPerTopic = 
setupPartitionsPerTopicWithTwoTopics(3, 3);
 
         Map<String, Subscription> consumers = new HashMap<>();
-        Subscription consumerLowSubscription = new Subscription(topics(topic1, 
topic2),
-                                                              null,
-                                                              
Collections.emptyList());
+        Subscription consumerLowSubscription = subscription(topics(topic1, 
topic2), 0);
         consumerLowSubscription.setGroupInstanceId(Optional.of(instance1));
         consumers.put(consumerIdLow, consumerLowSubscription);
-        Subscription consumerHighSubscription = new 
Subscription(topics(topic1, topic2),
-                                                              null,
-                                                              
Collections.emptyList());
+        Subscription consumerHighSubscription = subscription(topics(topic1, 
topic2), 1);
         consumerHighSubscription.setGroupInstanceId(Optional.of(instance2));
         consumers.put(consumerIdHigh, consumerHighSubscription);
-        Map<String, List<TopicPartition>> assignment = 
assignor.assign(partitionsPerTopic, consumers);
+        Map<String, List<TopicPartition>> assignment = 
assignor.assignPartitions(partitionsPerTopic, consumers);
         assertAssignment(partitions(tp(topic1, 0), tp(topic1, 1), tp(topic2, 
0), tp(topic2, 1)), assignment.get(consumerIdLow));
         assertAssignment(partitions(tp(topic1, 2), tp(topic2, 2)), 
assignment.get(consumerIdHigh));
     }
 
-    @Test
-    public void 
testOneStaticConsumerAndOneDynamicConsumerTwoTopicsSixPartitions() {
+    @ParameterizedTest(name = TEST_NAME_WITH_RACK_CONFIG)
+    @EnumSource(RackConfig.class)
+    public void 
testOneStaticConsumerAndOneDynamicConsumerTwoTopicsSixPartitions(RackConfig 
rackConfig) {
+        initializeRacks(rackConfig);
         // although consumer high has a higher rank than low, consumer low 
will win the comparison
         // because it has instance id while consumer 2 doesn't.
         String consumerIdLow = "consumer-b";
         String consumerIdHigh = "consumer-a";
 
-        Map<String, Integer> partitionsPerTopic = 
setupPartitionsPerTopicWithTwoTopics(3, 3);
+        Map<String, List<PartitionInfo>> partitionsPerTopic = 
setupPartitionsPerTopicWithTwoTopics(3, 3);
 
         Map<String, Subscription> consumers = new HashMap<>();
 
-        Subscription consumerLowSubscription = new Subscription(topics(topic1, 
topic2),
-                                                              null,
-                                                              
Collections.emptyList());
+        Subscription consumerLowSubscription = subscription(topics(topic1, 
topic2), 0);
         consumerLowSubscription.setGroupInstanceId(Optional.of(instance1));
         consumers.put(consumerIdLow, consumerLowSubscription);
-        consumers.put(consumerIdHigh, new Subscription(topics(topic1, 
topic2)));
+        consumers.put(consumerIdHigh, subscription(topics(topic1, topic2), 1));
 
-        Map<String, List<TopicPartition>> assignment = 
assignor.assign(partitionsPerTopic, consumers);
+        Map<String, List<TopicPartition>> assignment = 
assignor.assignPartitions(partitionsPerTopic, consumers);
         assertAssignment(partitions(tp(topic1, 0), tp(topic1, 1), tp(topic2, 
0), tp(topic2, 1)), assignment.get(consumerIdLow));
         assertAssignment(partitions(tp(topic1, 2), tp(topic2, 2)), 
assignment.get(consumerIdHigh));
     }
 
-    @Test
-    public void testStaticMemberRangeAssignmentPersistent() {
-        Map<String, Integer> partitionsPerTopic = 
setupPartitionsPerTopicWithTwoTopics(5, 4);
+    @ParameterizedTest(name = TEST_NAME_WITH_RACK_CONFIG)
+    @EnumSource(RackConfig.class)
+    public void testStaticMemberRangeAssignmentPersistent(RackConfig 
rackConfig) {
+        initializeRacks(rackConfig, 5);
+        Map<String, List<PartitionInfo>> partitionsPerTopic = 
setupPartitionsPerTopicWithTwoTopics(5, 4);
 
         Map<String, Subscription> consumers = new HashMap<>();
+        int consumerIndex = 0;
         for (MemberInfo m : staticMemberInfos) {
-            Subscription subscription = new Subscription(topics(topic1, 
topic2),
-                                                         null,
-                                                         
Collections.emptyList());
+            Subscription subscription = subscription(topics(topic1, topic2), 
consumerIndex++);
             subscription.setGroupInstanceId(m.groupInstanceId);
             consumers.put(m.memberId, subscription);
         }
         // Consumer 4 is a dynamic member.
         String consumer4 = "consumer4";
-        consumers.put(consumer4, new Subscription(topics(topic1, topic2)));
+        consumers.put(consumer4, subscription(topics(topic1, topic2), 
consumerIndex++));
 
         Map<String, List<TopicPartition>> expectedAssignment = new HashMap<>();
         // Have 3 static members instance1, instance2, instance3 to be 
persistent
@@ -250,29 +286,30 @@ public class RangeAssignorTest {
         expectedAssignment.put(consumer3, partitions(tp(topic1, 3), tp(topic2, 
2)));
         expectedAssignment.put(consumer4, partitions(tp(topic1, 4), tp(topic2, 
3)));
 
-        Map<String, List<TopicPartition>> assignment = 
assignor.assign(partitionsPerTopic, consumers);
+        Map<String, List<TopicPartition>> assignment = 
assignor.assignPartitions(partitionsPerTopic, consumers);
         assertEquals(expectedAssignment, assignment);
 
         // Replace dynamic member 4 with a new dynamic member 5.
         consumers.remove(consumer4);
         String consumer5 = "consumer5";
-        consumers.put(consumer5, new Subscription(topics(topic1, topic2)));
+        consumers.put(consumer5, subscription(topics(topic1, topic2), 
consumerIndex++));
 
         expectedAssignment.remove(consumer4);
         expectedAssignment.put(consumer5, partitions(tp(topic1, 4), tp(topic2, 
3)));
-        assignment = assignor.assign(partitionsPerTopic, consumers);
+        assignment = assignor.assignPartitions(partitionsPerTopic, consumers);
         assertEquals(expectedAssignment, assignment);
     }
 
-    @Test
-    public void 
testStaticMemberRangeAssignmentPersistentAfterMemberIdChanges() {
-        Map<String, Integer> partitionsPerTopic = 
setupPartitionsPerTopicWithTwoTopics(5, 5);
+    @ParameterizedTest(name = TEST_NAME_WITH_RACK_CONFIG)
+    @EnumSource(RackConfig.class)
+    public void 
testStaticMemberRangeAssignmentPersistentAfterMemberIdChanges(RackConfig 
rackConfig) {
+        initializeRacks(rackConfig);
+        Map<String, List<PartitionInfo>> partitionsPerTopic = 
setupPartitionsPerTopicWithTwoTopics(5, 5);
 
         Map<String, Subscription> consumers = new HashMap<>();
+        int consumerIndex = 0;
         for (MemberInfo m : staticMemberInfos) {
-            Subscription subscription = new Subscription(topics(topic1, 
topic2),
-                                                         null,
-                                                         
Collections.emptyList());
+            Subscription subscription = subscription(topics(topic1, topic2), 
consumerIndex++);
             subscription.setGroupInstanceId(m.groupInstanceId);
             consumers.put(m.memberId, subscription);
         }
@@ -302,10 +339,186 @@ public class RangeAssignorTest {
         assertEquals(staticAssignment, newStaticAssignment);
     }
 
-    static Map<String, List<TopicPartition>> 
checkStaticAssignment(AbstractPartitionAssignor assignor,
-                                                                   Map<String, 
Integer> partitionsPerTopic,
-                                                                   Map<String, 
Subscription> consumers) {
-        Map<String, List<TopicPartition>> assignmentByMemberId = 
assignor.assign(partitionsPerTopic, consumers);
+    @Test
+    public void 
testRackAwareStaticMemberRangeAssignmentPersistentAfterMemberIdChanges() {
+        initializeRacks(RackConfig.BROKER_AND_CONSUMER_RACK);
+        Map<String, List<PartitionInfo>> partitionsPerTopic = new HashMap<>();
+        int replicationFactor = 2;
+        int numBrokerRacks = 3;
+        partitionsPerTopic.put(topic1, 
AbstractPartitionAssignorTest.partitionInfos(topic1, 5, replicationFactor, 
numBrokerRacks, 0));
+        partitionsPerTopic.put(topic2,  
AbstractPartitionAssignorTest.partitionInfos(topic2, 5, replicationFactor, 
numBrokerRacks, 0));
+        List<MemberInfo> staticMemberInfos = new ArrayList<>();
+        staticMemberInfos.add(new MemberInfo(consumer1, 
Optional.of(instance1), Optional.of(ALL_RACKS[0])));
+        staticMemberInfos.add(new MemberInfo(consumer2, 
Optional.of(instance2), Optional.of(ALL_RACKS[1])));
+        staticMemberInfos.add(new MemberInfo(consumer3, 
Optional.of(instance3), Optional.of(ALL_RACKS[2])));
+
+        Map<String, Subscription> consumers = new HashMap<>();
+        int consumerIndex = 0;
+        for (MemberInfo m : staticMemberInfos) {
+            Subscription subscription = subscription(topics(topic1, topic2), 
consumerIndex++);
+            subscription.setGroupInstanceId(m.groupInstanceId);
+            consumers.put(m.memberId, subscription);
+        }
+        Map<String, List<TopicPartition>> expectedInstanceAssignment = new 
HashMap<>();
+        expectedInstanceAssignment.put(instance1,
+                partitions(tp(topic1, 0), tp(topic1, 2), tp(topic2, 0), 
tp(topic2, 2)));
+        expectedInstanceAssignment.put(instance2,
+                partitions(tp(topic1, 1), tp(topic1, 3), tp(topic2, 1), 
tp(topic2, 3)));
+        expectedInstanceAssignment.put(instance3,
+                partitions(tp(topic1, 4), tp(topic2, 4)));
+
+        Map<String, List<TopicPartition>> staticAssignment =
+                checkStaticAssignment(assignor, partitionsPerTopic, consumers);
+        assertEquals(expectedInstanceAssignment, staticAssignment);
+
+        // Now switch the member.id fields for each member info, the 
assignment should
+        // stay the same as last time.
+        String consumer4 = "consumer4";
+        String consumer5 = "consumer5";
+        consumers.put(consumer4, consumers.get(consumer3));
+        consumers.remove(consumer3);
+        consumers.put(consumer5, consumers.get(consumer2));
+        consumers.remove(consumer2);
+
+        Map<String, List<TopicPartition>> newStaticAssignment =
+                checkStaticAssignment(assignor, partitionsPerTopic, consumers);
+        assertEquals(staticAssignment, newStaticAssignment);
+    }
+
+    @Test
+    public void testRackAwareAssignmentWithUniformSubscription() {
+        Map<String, Integer> topics = mkMap(mkEntry("t1", 6), mkEntry("t2", 
7), mkEntry("t3", 2));
+        List<String> allTopics = asList("t1", "t2", "t3");
+        List<List<String>> consumerTopics = asList(allTopics, allTopics, 
allTopics);
+
+        // Verify combinations where rack-aware logic is not used.
+        verifyNonRackAwareAssignment(topics, consumerTopics,
+                asList("t1-0, t1-1, t2-0, t2-1, t2-2, t3-0", "t1-2, t1-3, 
t2-3, t2-4, t3-1", "t1-4, t1-5, t2-5, t2-6"));
+
+        // Verify best-effort rack-aware assignment for lower replication 
factor where racks have a subset of partitions.
+        verifyRackAssignment(assignor, topics, 1, racks(3), racks(3), 
consumerTopics,
+                asList("t1-0, t1-3, t2-0, t2-3, t2-6", "t1-1, t1-4, t2-1, 
t2-4, t3-0", "t1-2, t1-5, t2-2, t2-5, t3-1"), 0);
+        verifyRackAssignment(assignor, topics, 2, racks(3), racks(3), 
consumerTopics,
+                asList("t1-0, t1-2, t2-0, t2-2, t2-3, t3-1", "t1-1, t1-3, 
t2-1, t2-4, t3-0", "t1-4, t1-5, t2-5, t2-6"), 1);
+
+        // One consumer on a rack with no partitions
+        verifyRackAssignment(assignor, topics, 3, racks(2), racks(3), 
consumerTopics,
+                asList("t1-0, t1-1, t2-0, t2-1, t2-2, t3-0", "t1-2, t1-3, 
t2-3, t2-4, t3-1", "t1-4, t1-5, t2-5, t2-6"), 4);
+    }
+
+    @Test
+    public void testRackAwareAssignmentWithNonEqualSubscription() {
+        Map<String, Integer> topics = mkMap(mkEntry("t1", 6), mkEntry("t2", 
7), mkEntry("t3", 2));
+        List<String> allTopics = asList("t1", "t2", "t3");
+        List<List<String>> consumerTopics = asList(allTopics, allTopics, 
asList("t1", "t3"));
+
+        // Verify combinations where rack-aware logic is not used.
+        verifyNonRackAwareAssignment(topics, consumerTopics,
+                asList("t1-0, t1-1, t2-0, t2-1, t2-2, t2-3, t3-0", "t1-2, 
t1-3, t2-4, t2-5, t2-6, t3-1", "t1-4, t1-5"));
+
+        // Verify best-effort rack-aware assignment for lower replication 
factor where racks have a subset of partitions.
+        verifyRackAssignment(assignor, topics, 1, racks(3), racks(3), 
consumerTopics,
+                asList("t1-0, t1-3, t2-0, t2-2, t2-3, t2-6", "t1-1, t1-4, 
t2-1, t2-4, t2-5, t3-0", "t1-2, t1-5, t3-1"), 2);
+        verifyRackAssignment(assignor, topics, 2, racks(3), racks(3), 
consumerTopics,
+                asList("t1-0, t1-2, t2-0, t2-2, t2-3, t2-5, t3-1", "t1-1, 
t1-3, t2-1, t2-4, t2-6, t3-0", "t1-4, t1-5"), 0);
+
+        // One consumer on a rack with no partitions
+        verifyRackAssignment(assignor, topics, 3, racks(2), racks(3), 
consumerTopics,
+                asList("t1-0, t1-1, t2-0, t2-1, t2-2, t2-3, t3-0", "t1-2, 
t1-3, t2-4, t2-5, t2-6, t3-1", "t1-4, t1-5"), 2);
+    }
+
+    @Test
+    public void testRackAwareAssignmentWithUniformPartitions() {
+        Map<String, Integer> topics = mkMap(mkEntry("t1", 5), mkEntry("t2", 
5), mkEntry("t3", 5));
+        List<String> allTopics = asList("t1", "t2", "t3");
+        List<List<String>> consumerTopics = asList(allTopics, allTopics, 
allTopics);
+        List<String> nonRackAwareAssignment = asList(
+                "t1-0, t1-1, t2-0, t2-1, t3-0, t3-1",
+                "t1-2, t1-3, t2-2, t2-3, t3-2, t3-3",
+                "t1-4, t2-4, t3-4"
+        );
+
+        // Verify combinations where rack-aware logic is not used.
+        verifyNonRackAwareAssignment(topics, consumerTopics, 
nonRackAwareAssignment);
+
+        // Verify that co-partitioning is prioritized over rack-alignment
+        verifyRackAssignment(assignor, topics, 1, racks(3), racks(3), 
consumerTopics, nonRackAwareAssignment, 10);
+        verifyRackAssignment(assignor, topics, 2, racks(3), racks(3), 
consumerTopics, nonRackAwareAssignment, 5);
+        verifyRackAssignment(assignor, topics, 3, racks(2), racks(3), 
consumerTopics, nonRackAwareAssignment, 3);
+    }
+
+    @Test
+    public void 
testRackAwareAssignmentWithUniformPartitionsNonEqualSubscription() {
+        Map<String, Integer> topics = mkMap(mkEntry("t1", 5), mkEntry("t2", 
5), mkEntry("t3", 5));
+        List<String> allTopics = asList("t1", "t2", "t3");
+        List<List<String>> consumerTopics = asList(allTopics, allTopics, 
asList("t1", "t3"));
+
+        // Verify combinations where rack-aware logic is not used.
+        verifyNonRackAwareAssignment(topics, consumerTopics,
+                asList("t1-0, t1-1, t2-0, t2-1, t2-2, t3-0, t3-1", "t1-2, 
t1-3, t2-3, t2-4, t3-2, t3-3", "t1-4, t3-4"));
+
+        // Verify that co-partitioning is prioritized over rack-alignment for 
topics with equal subscriptions
+        verifyRackAssignment(assignor, topics, 1, racks(3), racks(3), 
consumerTopics,
+                asList("t1-0, t1-1, t2-0, t2-1, t2-4, t3-0, t3-1", "t1-2, 
t1-3, t2-2, t2-3, t3-2, t3-3", "t1-4, t3-4"), 9);
+        verifyRackAssignment(assignor, topics, 2, racks(3), racks(3), 
consumerTopics,
+                asList("t1-2, t2-0, t2-1, t2-3, t3-2", "t1-0, t1-3, t2-2, 
t2-4, t3-0, t3-3", "t1-1, t1-4, t3-1, t3-4"), 0);
+
+        // One consumer on a rack with no partitions
+        verifyRackAssignment(assignor, topics, 3, racks(2), racks(3), 
consumerTopics,
+                asList("t1-0, t1-1, t2-0, t2-1, t2-2, t3-0, t3-1", "t1-2, 
t1-3, t2-3, t2-4, t3-2, t3-3", "t1-4, t3-4"), 2);
+    }
+
+    @Test
+    public void testRackAwareAssignmentWithCoPartitioning() {
+        Map<String, Integer> topics = mkMap(mkEntry("t1", 6), mkEntry("t2", 
6), mkEntry("t3", 2), mkEntry("t4", 2));
+        List<List<String>> consumerTopics = asList(asList("t1", "t2"), 
asList("t1", "t2"), asList("t3", "t4"), asList("t3", "t4"));
+        List<String> consumerRacks = asList(ALL_RACKS[0], ALL_RACKS[1], 
ALL_RACKS[1], ALL_RACKS[0]);
+        List<String> nonRackAwareAssignment = asList(
+                "t1-0, t1-1, t1-2, t2-0, t2-1, t2-2",
+                "t1-3, t1-4, t1-5, t2-3, t2-4, t2-5",
+                "t3-0, t4-0",
+                "t3-1, t4-1"
+        );
+
+        verifyRackAssignment(assignor, topics, 3, racks(2), consumerRacks, 
consumerTopics, nonRackAwareAssignment, -1);
+        verifyRackAssignment(assignor, topics, 3, racks(2), consumerRacks, 
consumerTopics, nonRackAwareAssignment, -1);
+        verifyRackAssignment(assignor, topics, 2, racks(2), consumerRacks, 
consumerTopics, nonRackAwareAssignment, 0);
+        verifyRackAssignment(assignor, topics, 1, racks(2), consumerRacks, 
consumerTopics,
+                asList("t1-0, t1-2, t1-4, t2-0, t2-2, t2-4", "t1-1, t1-3, 
t1-5, t2-1, t2-3, t2-5", "t3-1, t4-1", "t3-0, t4-0"), 0);
+
+        List<String> allTopics = asList("t1", "t2", "t3", "t4");
+        consumerTopics = asList(allTopics, allTopics, allTopics, allTopics);
+        nonRackAwareAssignment = asList(
+                "t1-0, t1-1, t2-0, t2-1, t3-0, t4-0",
+                "t1-2, t1-3, t2-2, t2-3, t3-1, t4-1",
+                "t1-4, t2-4",
+                "t1-5, t2-5"
+        );
+        verifyRackAssignment(assignor, topics, 3, racks(2), consumerRacks, 
consumerTopics, nonRackAwareAssignment, -1);
+        verifyRackAssignment(assignor, topics, 3, racks(2), consumerRacks, 
consumerTopics, nonRackAwareAssignment, -1);
+        verifyRackAssignment(assignor, topics, 2, racks(2), consumerRacks, 
consumerTopics, nonRackAwareAssignment, 0);
+        verifyRackAssignment(assignor, topics, 1, racks(2), consumerRacks, 
consumerTopics,
+                asList("t1-0, t1-2, t2-0, t2-2, t3-0, t4-0", "t1-1, t1-3, 
t2-1, t2-3, t3-1, t4-1", "t1-4, t2-4", "t1-5, t2-5"), 4);
+        verifyRackAssignment(assignor, topics, 1, racks(3), consumerRacks, 
consumerTopics, nonRackAwareAssignment, 10);
+    }
+
+    private void verifyNonRackAwareAssignment(Map<String, Integer> topics, 
List<List<String>> consumerTopics, List<String> nonRackAwareAssignment) {
+        verifyRackAssignment(assignor, topics, 3, nullRacks(3), racks(3), 
consumerTopics, nonRackAwareAssignment, -1);
+        verifyRackAssignment(assignor, topics, 3, racks(3), nullRacks(3), 
consumerTopics, nonRackAwareAssignment, -1);
+        verifyRackAssignment(assignor, topics, 3, racks(3), racks(3), 
consumerTopics, nonRackAwareAssignment, 0);
+        verifyRackAssignment(assignor, topics, 4, racks(4), racks(3), 
consumerTopics, nonRackAwareAssignment, 0);
+        verifyRackAssignment(assignor, topics, 3, racks(3), asList("d", "e", 
"f"), consumerTopics, nonRackAwareAssignment, -1);
+        verifyRackAssignment(assignor, topics, 3, racks(3), asList(null, "e", 
"f"), consumerTopics, nonRackAwareAssignment, -1);
+
+        AbstractPartitionAssignorTest.preferRackAwareLogic(assignor, true);
+        verifyRackAssignment(assignor, topics, 3, racks(3), racks(3), 
consumerTopics, nonRackAwareAssignment, 0);
+        AbstractPartitionAssignorTest.preferRackAwareLogic(assignor, false);
+    }
+
+    private static Map<String, List<TopicPartition>> 
checkStaticAssignment(AbstractPartitionAssignor assignor,
+                                                                           
Map<String, List<PartitionInfo>> partitionsPerTopic,
+                                                                           
Map<String, Subscription> consumers) {
+        Map<String, List<TopicPartition>> assignmentByMemberId = 
assignor.assignPartitions(partitionsPerTopic, consumers);
         Map<String, List<TopicPartition>> assignmentByInstanceId = new 
HashMap<>();
         for (Map.Entry<String, Subscription> entry : consumers.entrySet()) {
             String memberId = entry.getKey();
@@ -320,13 +533,23 @@ public class RangeAssignorTest {
         assertEquals(new HashSet<>(expected), new HashSet<>(actual));
     }
 
-    private Map<String, Integer> setupPartitionsPerTopicWithTwoTopics(int 
numberOfPartitions1, int numberOfPartitions2) {
-        Map<String, Integer> partitionsPerTopic = new HashMap<>();
-        partitionsPerTopic.put(topic1, numberOfPartitions1);
-        partitionsPerTopic.put(topic2, numberOfPartitions2);
+    private Map<String, List<PartitionInfo>> 
setupPartitionsPerTopicWithTwoTopics(int numberOfPartitions1, int 
numberOfPartitions2) {
+        Map<String, List<PartitionInfo>> partitionsPerTopic = new HashMap<>();
+        partitionsPerTopic.put(topic1, partitionInfos(topic1, 
numberOfPartitions1));
+        partitionsPerTopic.put(topic2, partitionInfos(topic2, 
numberOfPartitions2));
         return partitionsPerTopic;
     }
 
+    private List<PartitionInfo> partitionInfos(String topic, int 
numberOfPartitions) {
+        return AbstractPartitionAssignorTest.partitionInfos(topic, 
numberOfPartitions, replicationFactor, numBrokerRacks, 0);
+    }
+
+    private Subscription subscription(List<String> topics, int consumerIndex) {
+        int numRacks = numBrokerRacks > 0 ? numBrokerRacks : ALL_RACKS.length;
+        Optional<String> rackId = Optional.ofNullable(hasConsumerRack ? 
ALL_RACKS[consumerIndex % numRacks] : null);
+        return new Subscription(topics, null, Collections.emptyList(), -1, 
rackId);
+    }
+
     private static List<String> topics(String... topics) {
         return Arrays.asList(topics);
     }
@@ -338,4 +561,17 @@ public class RangeAssignorTest {
     private static TopicPartition tp(String topic, int partition) {
         return new TopicPartition(topic, partition);
     }
+
+    void initializeRacks(RackConfig rackConfig) {
+        initializeRacks(rackConfig, 3);
+    }
+
+    void initializeRacks(RackConfig rackConfig, int maxConsumers) {
+        this.replicationFactor = maxConsumers;
+        this.numBrokerRacks = rackConfig != RackConfig.NO_BROKER_RACK ? 
maxConsumers : 0;
+        this.hasConsumerRack = rackConfig != RackConfig.NO_CONSUMER_RACK;
+        // Rack and consumer ordering are the same in all the tests, so we can 
verify
+        // rack-aware logic using the same tests.
+        AbstractPartitionAssignorTest.preferRackAwareLogic(assignor, true);
+    }
 }
diff --git 
a/clients/src/test/java/org/apache/kafka/clients/consumer/RoundRobinAssignorTest.java
 
b/clients/src/test/java/org/apache/kafka/clients/consumer/RoundRobinAssignorTest.java
index 19cd68c3d90..256491b9d7a 100644
--- 
a/clients/src/test/java/org/apache/kafka/clients/consumer/RoundRobinAssignorTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/clients/consumer/RoundRobinAssignorTest.java
@@ -29,8 +29,8 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.stream.Collectors;
 
-import static 
org.apache.kafka.clients.consumer.RangeAssignorTest.checkStaticAssignment;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
@@ -335,4 +335,13 @@ public class RoundRobinAssignorTest {
         partitionsPerTopic.put(topic2, numberOfPartitions2);
         return partitionsPerTopic;
     }
+
+    private static Map<String, List<TopicPartition>> 
checkStaticAssignment(AbstractPartitionAssignor assignor,
+                                                                           
Map<String, Integer> partitionsPerTopic,
+                                                                           
Map<String, Subscription> consumers) {
+        Map<String, List<TopicPartition>> assignmentByMemberId = 
assignor.assign(partitionsPerTopic, consumers);
+        return consumers.entrySet().stream()
+                .filter(e -> e.getValue().groupInstanceId().isPresent())
+                .collect(Collectors.toMap(e -> 
e.getValue().groupInstanceId().get(), e -> 
assignmentByMemberId.get(e.getKey())));
+    }
 }
diff --git 
a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractPartitionAssignorTest.java
 
b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractPartitionAssignorTest.java
index 9d0423dc142..3fe43048a28 100644
--- 
a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractPartitionAssignorTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractPartitionAssignorTest.java
@@ -16,21 +16,45 @@
  */
 package org.apache.kafka.clients.consumer.internals;
 
+import 
org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.Subscription;
+import org.apache.kafka.clients.consumer.RangeAssignor;
 import 
org.apache.kafka.clients.consumer.internals.AbstractPartitionAssignor.MemberInfo;
+import org.apache.kafka.common.Node;
+import org.apache.kafka.common.PartitionInfo;
+import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.utils.Utils;
 import org.junit.jupiter.api.Test;
 
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
 import java.util.Optional;
 import java.util.Random;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 
+import static 
org.apache.kafka.clients.consumer.internals.AbstractStickyAssignor.DEFAULT_GENERATION;
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
 
 public class AbstractPartitionAssignorTest {
 
+    public static final String TEST_NAME_WITH_RACK_CONFIG = 
"{displayName}.rackConfig = {0}";
+    public static final String TEST_NAME_WITH_CONSUMER_RACK = 
"{displayName}.hasConsumerRack = {0}";
+    public static final String[] ALL_RACKS = {"a", "b", "c", "d", "e", "f"};
+
+    public enum RackConfig {
+        NO_BROKER_RACK,
+        NO_CONSUMER_RACK,
+        BROKER_AND_CONSUMER_RACK
+    }
+
     @Test
     public void testMemberInfoSortingWithoutGroupInstanceId() {
         MemberInfo m1 = new MemberInfo("a", Optional.empty());
@@ -86,4 +110,149 @@ public class AbstractPartitionAssignorTest {
         Collections.shuffle(memberInfoList);
         assertEquals(staticMemberList, Utils.sorted(memberInfoList));
     }
+
+    @Test
+    public void testUseRackAwareAssignment() {
+        AbstractPartitionAssignor assignor = new RangeAssignor();
+        String[] racks = new String[] {"a", "b", "c"};
+        Set<String> allRacks = Utils.mkSet(racks);
+        Set<String> twoRacks = Utils.mkSet("a", "b");
+        Map<TopicPartition, Set<String>> partitionsOnAllRacks = new 
HashMap<>();
+        Map<TopicPartition, Set<String>> partitionsOnSubsetOfRacks = new 
HashMap<>();
+        for (int i = 0; i < 10; i++) {
+            TopicPartition tp = new TopicPartition("topic", i);
+            partitionsOnAllRacks.put(tp, allRacks);
+            partitionsOnSubsetOfRacks.put(tp, Utils.mkSet(racks[i % 
racks.length]));
+        }
+        assertFalse(assignor.useRackAwareAssignment(Collections.emptySet(), 
Collections.emptySet(), partitionsOnAllRacks));
+        assertFalse(assignor.useRackAwareAssignment(Collections.emptySet(), 
allRacks, partitionsOnAllRacks));
+        assertFalse(assignor.useRackAwareAssignment(allRacks, 
Collections.emptySet(), Collections.emptyMap()));
+        assertFalse(assignor.useRackAwareAssignment(Utils.mkSet("d"), 
allRacks, partitionsOnAllRacks));
+        assertFalse(assignor.useRackAwareAssignment(allRacks, allRacks, 
partitionsOnAllRacks));
+        assertFalse(assignor.useRackAwareAssignment(twoRacks, allRacks, 
partitionsOnAllRacks));
+        assertFalse(assignor.useRackAwareAssignment(Utils.mkSet("a", "d"), 
allRacks, partitionsOnAllRacks));
+        assertTrue(assignor.useRackAwareAssignment(allRacks, allRacks, 
partitionsOnSubsetOfRacks));
+        assertTrue(assignor.useRackAwareAssignment(twoRacks, allRacks, 
partitionsOnSubsetOfRacks));
+        assertTrue(assignor.useRackAwareAssignment(Utils.mkSet("a", "d"), 
allRacks, partitionsOnSubsetOfRacks));
+
+        assignor.preferRackAwareLogic = true;
+        assertFalse(assignor.useRackAwareAssignment(Collections.emptySet(), 
Collections.emptySet(), partitionsOnAllRacks));
+        assertFalse(assignor.useRackAwareAssignment(Collections.emptySet(), 
allRacks, partitionsOnAllRacks));
+        assertFalse(assignor.useRackAwareAssignment(allRacks, 
Collections.emptySet(), Collections.emptyMap()));
+        assertFalse(assignor.useRackAwareAssignment(Utils.mkSet("d"), 
allRacks, partitionsOnAllRacks));
+        assertTrue(assignor.useRackAwareAssignment(allRacks, allRacks, 
partitionsOnAllRacks));
+        assertTrue(assignor.useRackAwareAssignment(twoRacks, allRacks, 
partitionsOnAllRacks));
+        assertTrue(assignor.useRackAwareAssignment(allRacks, allRacks, 
partitionsOnSubsetOfRacks));
+        assertTrue(assignor.useRackAwareAssignment(twoRacks, allRacks, 
partitionsOnSubsetOfRacks));
+    }
+
+    public static List<String> racks(int numRacks) {
+        List<String> racks = new ArrayList<>(numRacks);
+        for (int i = 0; i < numRacks; i++)
+            racks.add(ALL_RACKS[i % ALL_RACKS.length]);
+        return racks;
+    }
+
+    public static List<String> nullRacks(int numRacks) {
+        return Arrays.asList(new String[numRacks]);
+    }
+
+    public static void verifyRackAssignment(AbstractPartitionAssignor assignor,
+                                            Map<String, Integer> 
numPartitionsPerTopic,
+                                            int replicationFactor,
+                                            List<String> brokerRacks,
+                                            List<String> consumerRacks,
+                                            List<List<String>> consumerTopics,
+                                            List<String> expectedAssignments,
+                                            int numPartitionsWithRackMismatch) 
{
+        List<String> consumers = IntStream.range(0, 
consumerRacks.size()).mapToObj(i -> "consumer" + 
i).collect(Collectors.toList());
+        List<Subscription> subscriptions = subscriptions(consumerTopics, 
consumerRacks);
+        Map<String, List<PartitionInfo>> partitionsPerTopic = 
partitionsPerTopic(numPartitionsPerTopic, replicationFactor, brokerRacks);
+
+        Map<String, Subscription> subscriptionsByConsumer = new 
HashMap<>(consumers.size());
+        for (int i = 0; i < subscriptions.size(); i++)
+            subscriptionsByConsumer.put(consumers.get(i), 
subscriptions.get(i));
+
+        Map<String, String> expectedAssignment = new 
HashMap<>(consumers.size());
+        for (int i = 0; i < consumers.size(); i++)
+            expectedAssignment.put(consumers.get(i), 
expectedAssignments.get(i));
+
+        Map<String, List<TopicPartition>> assignment = 
assignor.assignPartitions(partitionsPerTopic, subscriptionsByConsumer);
+        Map<String, String> actualAssignment = assignment.entrySet().stream()
+                .collect(Collectors.toMap(Entry::getKey, e -> 
toSortedString(e.getValue())));
+        assertEquals(expectedAssignment, actualAssignment);
+
+        if (numPartitionsWithRackMismatch >= 0) {
+            List<TopicPartition> numMismatched = new ArrayList<>();
+            for (int i = 0; i < consumers.size(); i++) {
+                String rack = consumerRacks.get(i);
+                if (rack != null) {
+                    List<TopicPartition> partitions = 
assignment.get(consumers.get(i));
+                    for (TopicPartition tp : partitions) {
+                        PartitionInfo partitionInfo = 
partitionsPerTopic.get(tp.topic()).stream()
+                                .filter(p -> p.topic().equals(tp.topic()) && 
p.partition() == tp.partition())
+                                .findFirst().get();
+                        if 
(Arrays.stream(partitionInfo.replicas()).noneMatch(n -> rack.equals(n.rack())))
+                            numMismatched.add(tp);
+                    }
+                }
+            }
+            assertEquals(numPartitionsWithRackMismatch, numMismatched.size(), 
"Partitions with rack mismatch " + numMismatched);
+        }
+    }
+
+    private static String toSortedString(List<?> partitions) {
+        return 
Utils.join(partitions.stream().map(Object::toString).sorted().collect(Collectors.toList()),
 ", ");
+    }
+
+    private static List<Subscription> subscriptions(List<List<String>> 
consumerTopics, List<String> consumerRacks) {
+        List<Subscription> subscriptions = new 
ArrayList<>(consumerTopics.size());
+        for (int i = 0; i < consumerTopics.size(); i++)
+            subscriptions.add(new Subscription(consumerTopics.get(i), null, 
Collections.emptyList(), DEFAULT_GENERATION, 
Optional.ofNullable(consumerRacks.get(i))));
+        return subscriptions;
+    }
+
+    private static Map<String, List<PartitionInfo>> 
partitionsPerTopic(Map<String, Integer> numPartitionsPerTopic,
+                                                                       int 
replicationFactor,
+                                                                       
List<String> brokerRacks) {
+        Map<String, List<PartitionInfo>> partitionsPerTopic = new HashMap<>();
+        int nextIndex = 0;
+        for (Map.Entry<String, Integer> entry : 
numPartitionsPerTopic.entrySet()) {
+            String topic = entry.getKey();
+            int numPartitions = entry.getValue();
+            partitionsPerTopic.put(topic, partitionInfos(topic, numPartitions, 
replicationFactor, brokerRacks, nextIndex));
+            nextIndex += numPartitions;
+        }
+        return partitionsPerTopic;
+    }
+
+    private static List<PartitionInfo> partitionInfos(String topic, int 
numberOfPartitions, int replicationFactor, List<String> brokerRacks, int 
nextNodeIndex) {
+        int numBrokers = brokerRacks.size();
+        List<Node> nodes = new ArrayList<>(numBrokers);
+        for (int i = 0; i < brokerRacks.size(); i++) {
+            nodes.add(new Node(i, "", i, brokerRacks.get(i)));
+        }
+        List<PartitionInfo> partitionInfos = new 
ArrayList<>(numberOfPartitions);
+        for (int i = 0; i < numberOfPartitions; i++) {
+            Node[] replicas = new Node[replicationFactor];
+            for (int j = 0; j < replicationFactor; j++) {
+                replicas[j] = nodes.get((i + j + nextNodeIndex) % 
nodes.size());
+            }
+            partitionInfos.add(new PartitionInfo(topic, i, replicas[0], 
replicas, replicas));
+        }
+        return partitionInfos;
+    }
+
+    public static List<PartitionInfo> partitionInfos(String topic, int 
numberOfPartitions, int replicationFactor, int numBrokerRacks, int 
nextNodeIndex) {
+        int numBrokers = numBrokerRacks <= 0 ? replicationFactor : 
numBrokerRacks * replicationFactor;
+        List<String> brokerRacks = new ArrayList<>(numBrokers);
+        for (int i = 0; i < numBrokers; i++) {
+            brokerRacks.add(numBrokerRacks <= 0 ? null : ALL_RACKS[i % 
numBrokerRacks]);
+        }
+        return partitionInfos(topic, numberOfPartitions, replicationFactor, 
brokerRacks, nextNodeIndex);
+    }
+
+    public static void preferRackAwareLogic(AbstractPartitionAssignor 
assignor, boolean value) {
+        assignor.preferRackAwareLogic = value;
+    }
 }
diff --git 
a/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala 
b/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala
index 017b950435d..91af0dd7d38 100644
--- a/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala
+++ b/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala
@@ -15,12 +15,18 @@ package kafka.api
 import java.time.Duration
 import java.util
 import java.util.Arrays.asList
+import java.util.concurrent.TimeUnit
+import java.util.concurrent.locks.ReentrantLock
 import java.util.regex.Pattern
 import java.util.{Locale, Optional, Properties}
+
+import kafka.server.{KafkaServer, QuotaType}
 import kafka.utils.TestUtils
+import org.apache.kafka.clients.admin.{NewPartitions, NewTopic}
 import org.apache.kafka.clients.consumer._
 import org.apache.kafka.clients.producer.{ProducerConfig, ProducerRecord}
 import org.apache.kafka.common.{MetricName, TopicPartition}
+import org.apache.kafka.common.config.TopicConfig
 import org.apache.kafka.common.errors.{InvalidGroupIdException, 
InvalidTopicException}
 import org.apache.kafka.common.header.Headers
 import org.apache.kafka.common.record.{CompressionType, TimestampType}
@@ -29,20 +35,12 @@ import org.apache.kafka.common.utils.Utils
 import org.apache.kafka.test.{MockConsumerInterceptor, MockProducerInterceptor}
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.Test
-
-import scala.jdk.CollectionConverters._
-import scala.collection.mutable.Buffer
-import kafka.server.QuotaType
-import kafka.server.KafkaServer
-import org.apache.kafka.clients.admin.NewPartitions
-import org.apache.kafka.clients.admin.NewTopic
-import org.apache.kafka.common.config.TopicConfig
 import org.junit.jupiter.params.ParameterizedTest
 import org.junit.jupiter.params.provider.ValueSource
 
-import java.util.concurrent.TimeUnit
-import java.util.concurrent.locks.ReentrantLock
 import scala.collection.mutable
+import scala.collection.mutable.Buffer
+import scala.jdk.CollectionConverters._
 
 /* We have some tests in this class instead of `BaseConsumerTest` in order to 
keep the build time under control. */
 class PlaintextConsumerTest extends BaseConsumerTest {
@@ -1949,22 +1947,4 @@ class PlaintextConsumerTest extends BaseConsumerTest {
 
     consumer2.close()
   }
-
-  @Test
-  def testConsumerRackIdPropagatedToPartitionAssignor(): Unit = {
-    consumerConfig.setProperty(ConsumerConfig.CLIENT_RACK_CONFIG, "rack-a")
-    
consumerConfig.setProperty(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG, 
classOf[RackAwareAssignor].getName)
-    val consumer = createConsumer()
-    consumer.subscribe(Set(topic).asJava)
-    awaitAssignment(consumer, Set(tp, tp2))
-  }
 }
-
-class RackAwareAssignor extends RoundRobinAssignor {
-  override def assign(partitionsPerTopic: util.Map[String, Integer], 
subscriptions: util.Map[String, ConsumerPartitionAssignor.Subscription]): 
util.Map[String, util.List[TopicPartition]] = {
-    assertEquals(1, subscriptions.size())
-    assertEquals(Optional.of("rack-a"), 
subscriptions.values.asScala.head.rackId)
-    super.assign(partitionsPerTopic, subscriptions)
-  }
-}
-
diff --git 
a/core/src/test/scala/integration/kafka/server/FetchFromFollowerIntegrationTest.scala
 
b/core/src/test/scala/integration/kafka/server/FetchFromFollowerIntegrationTest.scala
index 822099f605b..1f940efc422 100644
--- 
a/core/src/test/scala/integration/kafka/server/FetchFromFollowerIntegrationTest.scala
+++ 
b/core/src/test/scala/integration/kafka/server/FetchFromFollowerIntegrationTest.scala
@@ -18,7 +18,8 @@ package integration.kafka.server
 
 import kafka.server.{BaseFetchRequestTest, KafkaConfig}
 import kafka.utils.{TestInfoUtils, TestUtils}
-import org.apache.kafka.clients.consumer.{ConsumerConfig, KafkaConsumer}
+import org.apache.kafka.clients.consumer.{ConsumerConfig, KafkaConsumer, 
RangeAssignor}
+import org.apache.kafka.clients.producer.ProducerRecord
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.protocol.{ApiKeys, Errors}
 import org.apache.kafka.common.requests.FetchResponse
@@ -28,7 +29,8 @@ import org.junit.jupiter.api.{Test, Timeout}
 import org.junit.jupiter.params.ParameterizedTest
 import org.junit.jupiter.params.provider.ValueSource
 
-import java.util.Properties
+import java.util.{Collections, Properties}
+import java.util.concurrent.{Executors, TimeUnit}
 import scala.jdk.CollectionConverters._
 
 class FetchFromFollowerIntegrationTest extends BaseFetchRequestTest {
@@ -168,6 +170,66 @@ class FetchFromFollowerIntegrationTest extends 
BaseFetchRequestTest {
     }
   }
 
+  @Test
+  def testRackAwareRangeAssignor(): Unit = {
+    val partitionList = servers.indices.toList
+
+    val topicWithAllPartitionsOnAllRacks = "topicWithAllPartitionsOnAllRacks"
+    createTopic(topicWithAllPartitionsOnAllRacks, servers.size, servers.size)
+
+    // Racks are in order of broker ids, assign leaders in reverse order
+    val topicWithSingleRackPartitions = "topicWithSingleRackPartitions"
+    createTopicWithAssignment(topicWithSingleRackPartitions, 
partitionList.map(i => (i, Seq(servers.size - i - 1))).toMap)
+
+    // Create consumers with instance ids in ascending order, with racks in 
the same order.
+    
consumerConfig.setProperty(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG, 
classOf[RangeAssignor].getName)
+    val consumers = servers.map { server =>
+      consumerConfig.setProperty(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, 
"earliest")
+      consumerConfig.setProperty(ConsumerConfig.CLIENT_RACK_CONFIG, 
server.config.rack.orNull)
+      consumerConfig.setProperty(ConsumerConfig.GROUP_INSTANCE_ID_CONFIG, 
s"instance-${server.config.brokerId}")
+      createConsumer()
+    }
+
+    val producer = createProducer()
+    val executor = Executors.newFixedThreadPool(consumers.size)
+
+    def verifyAssignments(assignments: List[Set[TopicPartition]]): Unit = {
+      val assignmentFutures = consumers.zipWithIndex.map { case (consumer, i) 
=>
+        executor.submit(() => {
+          val expectedAssignment = assignments(i)
+          TestUtils.pollUntilTrue(consumer, () => consumer.assignment() == 
expectedAssignment.asJava,
+            s"Timed out while awaiting expected assignment 
$expectedAssignment. The current assignment is ${consumer.assignment()}")
+        }, 0)
+      }
+      assignmentFutures.foreach(future => assertEquals(0, future.get(20, 
TimeUnit.SECONDS)))
+
+      assignments.flatten.foreach { tp =>
+        producer.send(new ProducerRecord(tp.topic, tp.partition, 
s"key-$tp".getBytes, s"value-$tp".getBytes))
+      }
+      consumers.zipWithIndex.foreach { case (consumer, i) =>
+        val records = TestUtils.pollUntilAtLeastNumRecords(consumer, 
assignments(i).size)
+        assertEquals(assignments(i), records.map(r => new 
TopicPartition(r.topic, r.partition)).toSet)
+      }
+    }
+
+    try {
+      // Rack-based assignment results in partitions assigned in reverse order 
since partition racks are in the reverse order.
+      
consumers.foreach(_.subscribe(Collections.singleton(topicWithSingleRackPartitions)))
+      verifyAssignments(partitionList.reverse.map(p => Set(new 
TopicPartition(topicWithSingleRackPartitions, p))))
+
+      // Non-rack-aware assignment results in ordered partitions.
+      
consumers.foreach(_.subscribe(Collections.singleton(topicWithAllPartitionsOnAllRacks)))
+      verifyAssignments(partitionList.map(p => Set(new 
TopicPartition(topicWithAllPartitionsOnAllRacks, p))))
+
+      // Rack-aware assignment with co-partitioning results in reverse 
assignment for both topics.
+      consumers.foreach(_.subscribe(Set(topicWithSingleRackPartitions, 
topicWithAllPartitionsOnAllRacks).asJava))
+      verifyAssignments(partitionList.reverse.map(p => Set(new 
TopicPartition(topicWithAllPartitionsOnAllRacks, p), new 
TopicPartition(topicWithSingleRackPartitions, p))))
+
+    } finally {
+      executor.shutdownNow()
+    }
+  }
+
   private def getPreferredReplica: Int = {
     val topicPartition = new TopicPartition(topic, 0)
     val offsetMap = Map(topicPartition -> 0L)

Reply via email to