This is an automated email from the ASF dual-hosted git repository.

guozhang pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 5ceaa588ee HOTFIX / KAFKA-14130: Reduce RackAwarenesssTest to unit Test (#12476)
5ceaa588ee is described below

commit 5ceaa588ee224fe1a859e4212636aba22b22f540
Author: Guozhang Wang <wangg...@gmail.com>
AuthorDate: Wed Aug 3 15:36:59 2022 -0700

    HOTFIX / KAFKA-14130: Reduce RackAwarenesssTest to unit Test (#12476)
    
    While working on KAFKA-13877, I felt it was overkill to introduce the whole test class as an integration test, since all we need is to test the assignor itself, which can be covered by a unit test. Running this suite with 9+ instances takes a long time and is still vulnerable to all kinds of timing-based flakiness. A better choice is to reduce it to a unit test, similar to HighAvailabilityStreamsPartitionAssignorTest, which just tests the behavior of the assignor itself, rather than creating m [...]
    
    Since we mock everything, there's no flakiness anymore. Plus, we greatly reduced the test runtime (on my local machine, the old integration test takes about 35 seconds to run the whole suite, while the new one takes 20ms on average).
    
    Reviewers: Divij Vaidya <di...@amazon.com>, Dalibor Plavcic
---
 .../integration/RackAwarenessIntegrationTest.java  | 433 ----------------
 .../RackAwarenessStreamsPartitionAssignorTest.java | 576 +++++++++++++++++++++
 .../processor/internals/StreamTaskTest.java        |   4 +-
 .../internals/assignment/AssignmentTestUtils.java  |   3 +
 4 files changed, 581 insertions(+), 435 deletions(-)
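
The heart of the change: the new test drives StreamsPartitionAssignor directly and replaces its broker-facing dependencies with Mockito mocks. As a rough, condensed sketch of that wiring (mirroring the createMockAdminClient helper in the new test below; the class and method names MockAdminSketch/mockedAdmin are illustrative and not part of the patch):

    import static org.mockito.Mockito.any;
    import static org.mockito.Mockito.mock;
    import static org.mockito.Mockito.when;

    import java.util.Map;
    import org.apache.kafka.clients.admin.Admin;
    import org.apache.kafka.clients.admin.AdminClient;
    import org.apache.kafka.clients.admin.ListOffsetsResult;
    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.common.internals.KafkaFutureImpl;

    class MockAdminSketch {
        // Stub Admin.listOffsets() so the assignor reads canned changelog end
        // offsets instead of querying a real broker (condensed sketch only).
        static Admin mockedAdmin(final Map<TopicPartition, ListOffsetsResult.ListOffsetsResultInfo> endOffsets) {
            final Admin adminClient = mock(AdminClient.class);
            final ListOffsetsResult result = mock(ListOffsetsResult.class);
            final KafkaFutureImpl<Map<TopicPartition, ListOffsetsResult.ListOffsetsResultInfo>> allFuture = new KafkaFutureImpl<>();
            allFuture.complete(endOffsets);
            // Any listOffsets() request gets the pre-baked result: no broker round trip,
            // hence no timing-based flakiness.
            when(adminClient.listOffsets(any())).thenReturn(result);
            when(result.all()).thenReturn(allFuture);
            return adminClient;
        }
    }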

diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/RackAwarenessIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/RackAwarenessIntegrationTest.java
deleted file mode 100644
index 7c93b769f5..0000000000
--- a/streams/src/test/java/org/apache/kafka/streams/integration/RackAwarenessIntegrationTest.java
+++ /dev/null
@@ -1,433 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.kafka.streams.integration;
-
-import org.apache.kafka.common.serialization.Serdes;
-import org.apache.kafka.streams.KafkaStreams;
-import org.apache.kafka.streams.StreamsBuilder;
-import org.apache.kafka.streams.StreamsConfig;
-import org.apache.kafka.streams.ThreadMetadata;
-import org.apache.kafka.streams.Topology;
-import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
-import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
-import org.apache.kafka.streams.kstream.Consumed;
-import org.apache.kafka.streams.kstream.KStream;
-import org.apache.kafka.streams.kstream.Repartitioned;
-import org.apache.kafka.streams.processor.TaskId;
-import org.apache.kafka.streams.state.KeyValueStore;
-import org.apache.kafka.streams.state.StoreBuilder;
-import org.apache.kafka.streams.state.Stores;
-import org.apache.kafka.test.TestUtils;
-import org.junit.jupiter.api.AfterEach;
-import org.junit.jupiter.api.BeforeAll;
-import org.junit.jupiter.api.BeforeEach;
-import org.junit.jupiter.api.Tag;
-import org.junit.jupiter.api.Test;
-import org.junit.jupiter.api.TestInfo;
-import org.junit.jupiter.api.Timeout;
-
-import java.io.IOException;
-import java.time.Duration;
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Properties;
-import java.util.function.Predicate;
-import java.util.stream.Collectors;
-import java.util.stream.Stream;
-
-import static java.util.Arrays.asList;
-import static java.util.Collections.singletonList;
-import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName;
-import static org.apache.kafka.test.TestUtils.waitForCondition;
-import static org.junit.jupiter.api.Assertions.assertEquals;
-
-@Timeout(600)
-@Tag("integration")
-public class RackAwarenessIntegrationTest {
-    private static final int NUM_BROKERS = 1;
-
-    private static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(NUM_BROKERS);
-
-    private static final String TAG_VALUE_K8_CLUSTER_1 = "k8s-cluster-1";
-    private static final String TAG_VALUE_K8_CLUSTER_2 = "k8s-cluster-2";
-    private static final String TAG_VALUE_K8_CLUSTER_3 = "k8s-cluster-3";
-    private static final String TAG_VALUE_EU_CENTRAL_1A = "eu-central-1a";
-    private static final String TAG_VALUE_EU_CENTRAL_1B = "eu-central-1b";
-    private static final String TAG_VALUE_EU_CENTRAL_1C = "eu-central-1c";
-
-    private static final int DEFAULT_NUMBER_OF_STATEFUL_SUB_TOPOLOGIES = 1;
-    private static final int DEFAULT_NUMBER_OF_PARTITIONS_OF_SUB_TOPOLOGIES = 2;
-
-    private static final String INPUT_TOPIC = "input-topic";
-
-    private static final String TAG_ZONE = "zone";
-    private static final String TAG_CLUSTER = "cluster";
-
-    private List<KafkaStreamsWithConfiguration> kafkaStreamsInstances;
-    private Properties baseConfiguration;
-    private Topology topology;
-
-    @BeforeAll
-    public static void createTopics() throws Exception {
-        CLUSTER.start();
-        CLUSTER.createTopic(INPUT_TOPIC, 6, 1);
-    }
-
-    @BeforeEach
-    public void setup(final TestInfo testInfo) {
-        kafkaStreamsInstances = new ArrayList<>();
-        baseConfiguration = new Properties();
-        final String safeTestName = safeUniqueTestName(getClass(), testInfo);
-        final String applicationId = "app-" + safeTestName;
-        baseConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, applicationId);
-        baseConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers());
-        baseConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath());
-        baseConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.Integer().getClass());
-        baseConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.Integer().getClass());
-    }
-
-    @AfterEach
-    public void cleanup() throws IOException {
-        for (final KafkaStreamsWithConfiguration kafkaStreamsWithConfiguration : kafkaStreamsInstances) {
-            kafkaStreamsWithConfiguration.kafkaStreams.close(Duration.ofMillis(IntegrationTestUtils.DEFAULT_TIMEOUT));
-            IntegrationTestUtils.purgeLocalStreamsState(kafkaStreamsWithConfiguration.configuration);
-        }
-        kafkaStreamsInstances.clear();
-    }
-
-    @Test
-    public void shouldDoRebalancingWithMaximumNumberOfClientTags() throws Exception {
-        initTopology(3, 3);
-        final int numberOfStandbyReplicas = 1;
-
-        final List<String> clientTagKeys = new ArrayList<>();
-        final Map<String, String> clientTags1 = new HashMap<>();
-        final Map<String, String> clientTags2 = new HashMap<>();
-
-        for (int i = 0; i < StreamsConfig.MAX_RACK_AWARE_ASSIGNMENT_TAG_LIST_SIZE; i++) {
-            clientTagKeys.add("key-" + i);
-        }
-
-        for (int i = 0; i < clientTagKeys.size(); i++) {
-            final String key = clientTagKeys.get(i);
-            clientTags1.put(key, "value-1-" + i);
-            clientTags2.put(key, "value-2-" + i);
-        }
-
-        assertEquals(StreamsConfig.MAX_RACK_AWARE_ASSIGNMENT_TAG_LIST_SIZE, clientTagKeys.size());
-        Stream.of(clientTags1, clientTags2)
-              .forEach(clientTags -> assertEquals(StreamsConfig.MAX_RACK_AWARE_ASSIGNMENT_TAG_LIST_SIZE,
-                                                  clientTags.size(),
-                                                  String.format("clientsTags with content '%s' " +
-                                                          "did not match expected size", clientTags)));
-
-        createAndStart(clientTags1, clientTagKeys, numberOfStandbyReplicas);
-        createAndStart(clientTags1, clientTagKeys, numberOfStandbyReplicas);
-        createAndStart(clientTags2, clientTagKeys, numberOfStandbyReplicas);
-
-        waitUntilAllKafkaStreamsClientsAreRunning();
-        waitForCondition(() -> isIdealTaskDistributionReachedForTags(clientTagKeys), "not all tags are evenly distributed");
-
-        stopKafkaStreamsInstanceWithIndex(0);
-
-        waitUntilAllKafkaStreamsClientsAreRunning();
-
-        waitForCondition(() -> isIdealTaskDistributionReachedForTags(clientTagKeys), "not all tags are evenly distributed");
-    }
-
-    @Test
-    public void shouldDistributeStandbyReplicasWhenAllClientsAreLocatedOnASameClusterTag() throws Exception {
-        initTopology();
-        final int numberOfStandbyReplicas = 1;
-        createAndStart(buildClientTags(TAG_VALUE_EU_CENTRAL_1A, TAG_VALUE_K8_CLUSTER_1), asList(TAG_ZONE, TAG_CLUSTER), numberOfStandbyReplicas);
-        createAndStart(buildClientTags(TAG_VALUE_EU_CENTRAL_1A, TAG_VALUE_K8_CLUSTER_1), asList(TAG_ZONE, TAG_CLUSTER), numberOfStandbyReplicas);
-        createAndStart(buildClientTags(TAG_VALUE_EU_CENTRAL_1A, TAG_VALUE_K8_CLUSTER_1), asList(TAG_ZONE, TAG_CLUSTER), numberOfStandbyReplicas);
-
-        createAndStart(buildClientTags(TAG_VALUE_EU_CENTRAL_1B, TAG_VALUE_K8_CLUSTER_1), asList(TAG_ZONE, TAG_CLUSTER), numberOfStandbyReplicas);
-        createAndStart(buildClientTags(TAG_VALUE_EU_CENTRAL_1B, TAG_VALUE_K8_CLUSTER_1), asList(TAG_ZONE, TAG_CLUSTER), numberOfStandbyReplicas);
-        createAndStart(buildClientTags(TAG_VALUE_EU_CENTRAL_1B, TAG_VALUE_K8_CLUSTER_1), asList(TAG_ZONE, TAG_CLUSTER), numberOfStandbyReplicas);
-
-        waitUntilAllKafkaStreamsClientsAreRunning();
-        waitForCondition(() -> isIdealTaskDistributionReachedForTags(singletonList(TAG_ZONE)), "not all tags are evenly distributed");
-    }
-
-    @Test
-    public void shouldDistributeStandbyReplicasOverMultipleClientTags() throws Exception {
-        initTopology();
-        final int numberOfStandbyReplicas = 2;
-
-        createAndStart(buildClientTags(TAG_VALUE_EU_CENTRAL_1A, TAG_VALUE_K8_CLUSTER_1), asList(TAG_ZONE, TAG_CLUSTER), numberOfStandbyReplicas);
-        createAndStart(buildClientTags(TAG_VALUE_EU_CENTRAL_1B, TAG_VALUE_K8_CLUSTER_1), asList(TAG_ZONE, TAG_CLUSTER), numberOfStandbyReplicas);
-        createAndStart(buildClientTags(TAG_VALUE_EU_CENTRAL_1C, TAG_VALUE_K8_CLUSTER_1), asList(TAG_ZONE, TAG_CLUSTER), numberOfStandbyReplicas);
-
-        createAndStart(buildClientTags(TAG_VALUE_EU_CENTRAL_1A, TAG_VALUE_K8_CLUSTER_2), asList(TAG_ZONE, TAG_CLUSTER), numberOfStandbyReplicas);
-        createAndStart(buildClientTags(TAG_VALUE_EU_CENTRAL_1B, TAG_VALUE_K8_CLUSTER_2), asList(TAG_ZONE, TAG_CLUSTER), numberOfStandbyReplicas);
-        createAndStart(buildClientTags(TAG_VALUE_EU_CENTRAL_1C, TAG_VALUE_K8_CLUSTER_2), asList(TAG_ZONE, TAG_CLUSTER), numberOfStandbyReplicas);
-
-        createAndStart(buildClientTags(TAG_VALUE_EU_CENTRAL_1A, TAG_VALUE_K8_CLUSTER_3), asList(TAG_ZONE, TAG_CLUSTER), numberOfStandbyReplicas);
-        createAndStart(buildClientTags(TAG_VALUE_EU_CENTRAL_1B, TAG_VALUE_K8_CLUSTER_3), asList(TAG_ZONE, TAG_CLUSTER), numberOfStandbyReplicas);
-        createAndStart(buildClientTags(TAG_VALUE_EU_CENTRAL_1C, TAG_VALUE_K8_CLUSTER_3), asList(TAG_ZONE, TAG_CLUSTER), numberOfStandbyReplicas);
-
-        waitUntilAllKafkaStreamsClientsAreRunning();
-        waitForCondition(() -> isIdealTaskDistributionReachedForTags(asList(TAG_ZONE, TAG_CLUSTER)), "not all tags are evenly distributed");
-    }
-
-    @Test
-    public void shouldDistributeStandbyReplicasWhenIdealDistributionCanNotBeAchieved() throws Exception {
-        initTopology();
-        final int numberOfStandbyReplicas = 2;
-
-        createAndStart(buildClientTags(TAG_VALUE_EU_CENTRAL_1A, TAG_VALUE_K8_CLUSTER_1), asList(TAG_ZONE, TAG_CLUSTER), numberOfStandbyReplicas);
-        createAndStart(buildClientTags(TAG_VALUE_EU_CENTRAL_1B, TAG_VALUE_K8_CLUSTER_1), asList(TAG_ZONE, TAG_CLUSTER), numberOfStandbyReplicas);
-        createAndStart(buildClientTags(TAG_VALUE_EU_CENTRAL_1C, TAG_VALUE_K8_CLUSTER_1), asList(TAG_ZONE, TAG_CLUSTER), numberOfStandbyReplicas);
-
-        createAndStart(buildClientTags(TAG_VALUE_EU_CENTRAL_1A, TAG_VALUE_K8_CLUSTER_2), asList(TAG_ZONE, TAG_CLUSTER), numberOfStandbyReplicas);
-        createAndStart(buildClientTags(TAG_VALUE_EU_CENTRAL_1B, TAG_VALUE_K8_CLUSTER_2), asList(TAG_ZONE, TAG_CLUSTER), numberOfStandbyReplicas);
-        createAndStart(buildClientTags(TAG_VALUE_EU_CENTRAL_1C, TAG_VALUE_K8_CLUSTER_2), asList(TAG_ZONE, TAG_CLUSTER), numberOfStandbyReplicas);
-
-        waitUntilAllKafkaStreamsClientsAreRunning();
-
-        waitForCondition(() -> isIdealTaskDistributionReachedForTags(singletonList(TAG_ZONE)), "not all tags are evenly distributed");
-        waitForCondition(() -> isPartialTaskDistributionReachedForTags(singletonList(TAG_CLUSTER)), "not all tags are evenly distributed");
-    }
-
-    private void stopKafkaStreamsInstanceWithIndex(final int index) {
-        kafkaStreamsInstances.get(index).kafkaStreams.close(Duration.ofMillis(IntegrationTestUtils.DEFAULT_TIMEOUT));
-        kafkaStreamsInstances.remove(index);
-    }
-
-    private void waitUntilAllKafkaStreamsClientsAreRunning() throws Exception {
-        waitUntilAllKafkaStreamsClientsAreRunning(Duration.ofMillis(IntegrationTestUtils.DEFAULT_TIMEOUT));
-    }
-
-    private void waitUntilAllKafkaStreamsClientsAreRunning(final Duration timeout) throws Exception {
-        IntegrationTestUtils.waitForApplicationState(kafkaStreamsInstances.stream().map(it -> it.kafkaStreams).collect(Collectors.toList()),
-                                                     KafkaStreams.State.RUNNING,
-                                                     timeout);
-    }
-
-    private boolean isPartialTaskDistributionReachedForTags(final Collection<String> tagsToCheck) {
-        final Predicate<TaskClientTagDistribution> partialTaskClientTagDistributionTest = taskClientTagDistribution -> {
-            final Map<String, String> activeTaskClientTags = taskClientTagDistribution.activeTaskClientTags.clientTags;
-            return tagsAmongstActiveAndAtLeastOneStandbyTaskIsDifferent(taskClientTagDistribution.standbyTasksClientTags, activeTaskClientTags, tagsToCheck);
-        };
-
-        return isTaskDistributionTestSuccessful(partialTaskClientTagDistributionTest);
-    }
-
-    private boolean isIdealTaskDistributionReachedForTags(final Collection<String> tagsToCheck) {
-        final Predicate<TaskClientTagDistribution> idealTaskClientTagDistributionTest = taskClientTagDistribution -> {
-            final Map<String, String> activeTaskClientTags = taskClientTagDistribution.activeTaskClientTags.clientTags;
-            return tagsAmongstStandbyTasksAreDifferent(taskClientTagDistribution.standbyTasksClientTags, tagsToCheck)
-                   && tagsAmongstActiveAndAllStandbyTasksAreDifferent(taskClientTagDistribution.standbyTasksClientTags,
-                                                                      activeTaskClientTags,
-                                                                      tagsToCheck);
-        };
-
-        return isTaskDistributionTestSuccessful(idealTaskClientTagDistributionTest);
-    }
-
-    private boolean isTaskDistributionTestSuccessful(final Predicate<TaskClientTagDistribution> taskClientTagDistributionPredicate) {
-        final List<TaskClientTagDistribution> tasksClientTagDistributions = getTasksClientTagDistributions();
-
-        if (tasksClientTagDistributions.isEmpty()) {
-            return false;
-        }
-
-        return tasksClientTagDistributions.stream().allMatch(taskClientTagDistributionPredicate);
-    }
-
-    private static boolean tagsAmongstActiveAndAllStandbyTasksAreDifferent(final Collection<TaskClientTags> standbyTasks,
-                                                                           final Map<String, String> activeTaskClientTags,
-                                                                           final Collection<String> tagsToCheck) {
-        return standbyTasks.stream().allMatch(standbyTask -> tagsToCheck.stream().noneMatch(tag -> activeTaskClientTags.get(tag).equals(standbyTask.clientTags.get(tag))));
-    }
-
-    private static boolean tagsAmongstActiveAndAtLeastOneStandbyTaskIsDifferent(final Collection<TaskClientTags> standbyTasks,
-                                                                                final Map<String, String> activeTaskClientTags,
-                                                                                final Collection<String> tagsToCheck) {
-        return standbyTasks.stream().anyMatch(standbyTask -> tagsToCheck.stream().noneMatch(tag -> activeTaskClientTags.get(tag).equals(standbyTask.clientTags.get(tag))));
-    }
-
-    private static boolean tagsAmongstStandbyTasksAreDifferent(final Collection<TaskClientTags> standbyTasks, final Collection<String> tagsToCheck) {
-        final Map<String, Integer> statistics = new HashMap<>();
-
-        for (final TaskClientTags standbyTask : standbyTasks) {
-            for (final String tag : tagsToCheck) {
-                final String tagValue = standbyTask.clientTags.get(tag);
-                final Integer tagValueOccurrence = statistics.getOrDefault(tagValue, 0);
-                statistics.put(tagValue, tagValueOccurrence + 1);
-            }
-        }
-
-        return statistics.values().stream().noneMatch(occurrence -> occurrence > 1);
-    }
-
-    private void initTopology() {
-        initTopology(DEFAULT_NUMBER_OF_PARTITIONS_OF_SUB_TOPOLOGIES, DEFAULT_NUMBER_OF_STATEFUL_SUB_TOPOLOGIES);
-    }
-
-    private void initTopology(final int numberOfPartitionsOfSubTopologies, final int numberOfStatefulSubTopologies) {
-        final StreamsBuilder builder = new StreamsBuilder();
-        final String stateStoreName = "myTransformState";
-
-        final StoreBuilder<KeyValueStore<Integer, Integer>> keyValueStoreBuilder = Stores.keyValueStoreBuilder(
-            Stores.persistentKeyValueStore(stateStoreName),
-            Serdes.Integer(),
-            Serdes.Integer()
-        );
-
-        builder.addStateStore(keyValueStoreBuilder);
-
-        final KStream<Integer, Integer> stream = builder.stream(INPUT_TOPIC, Consumed.with(Serdes.Integer(), Serdes.Integer()));
-
-        // Stateless sub-topology
-        stream.repartition(Repartitioned.numberOfPartitions(numberOfPartitionsOfSubTopologies)).filter((k, v) -> true);
-
-        // Stateful sub-topologies
-        for (int i = 0; i < numberOfStatefulSubTopologies; i++) {
-            stream.repartition(Repartitioned.numberOfPartitions(numberOfPartitionsOfSubTopologies))
-                  .groupByKey()
-                  .reduce(Integer::sum);
-        }
-
-        topology = builder.build();
-    }
-
-    private List<TaskClientTagDistribution> getTasksClientTagDistributions() {
-        final List<TaskClientTagDistribution> taskClientTags = new ArrayList<>();
-
-        for (final KafkaStreamsWithConfiguration kafkaStreamsInstance : kafkaStreamsInstances) {
-            final StreamsConfig config = new StreamsConfig(kafkaStreamsInstance.configuration);
-            for (final ThreadMetadata localThreadsMetadata : kafkaStreamsInstance.kafkaStreams.metadataForLocalThreads()) {
-                localThreadsMetadata.activeTasks().forEach(activeTask -> {
-                    final TaskId activeTaskId = activeTask.taskId();
-                    final Map<String, String> clientTags = config.getClientTags();
-
-                    final List<TaskClientTags> standbyTasks = findStandbysForActiveTask(activeTaskId);
-
-                    if (!standbyTasks.isEmpty()) {
-                        final TaskClientTags activeTaskView = new TaskClientTags(activeTaskId, clientTags);
-                        taskClientTags.add(new TaskClientTagDistribution(activeTaskView, standbyTasks));
-                    }
-                });
-
-            }
-        }
-
-        return taskClientTags;
-    }
-
-    private List<TaskClientTags> findStandbysForActiveTask(final TaskId taskId) {
-        final List<TaskClientTags> standbyTasks = new ArrayList<>();
-
-        for (final KafkaStreamsWithConfiguration kafkaStreamsInstance : kafkaStreamsInstances) {
-            for (final ThreadMetadata localThreadsMetadata : kafkaStreamsInstance.kafkaStreams.metadataForLocalThreads()) {
-                localThreadsMetadata.standbyTasks().forEach(standbyTask -> {
-                    final TaskId standbyTaskId = standbyTask.taskId();
-                    if (taskId.equals(standbyTaskId)) {
-                        final StreamsConfig config = new StreamsConfig(kafkaStreamsInstance.configuration);
-                        standbyTasks.add(new TaskClientTags(standbyTaskId, config.getClientTags()));
-                    }
-                });
-            }
-        }
-
-        return standbyTasks;
-    }
-
-    private static Map<String, String> buildClientTags(final String zone, final String cluster) {
-        final Map<String, String> clientTags = new HashMap<>();
-
-        clientTags.put(TAG_ZONE, zone);
-        clientTags.put(TAG_CLUSTER, cluster);
-
-        return clientTags;
-    }
-
-    private void createAndStart(final Map<String, String> clientTags,
-                                final Collection<String> rackAwareAssignmentTags,
-                                final int numberOfStandbyReplicas) {
-        final Properties streamsConfiguration = createStreamsConfiguration(clientTags, rackAwareAssignmentTags, numberOfStandbyReplicas);
-        final KafkaStreams kafkaStreams = new KafkaStreams(topology, streamsConfiguration);
-
-        kafkaStreamsInstances.add(new KafkaStreamsWithConfiguration(streamsConfiguration, kafkaStreams));
-
-        kafkaStreams.start();
-    }
-
-    private Properties createStreamsConfiguration(final Map<String, String> clientTags,
-                                                  final Collection<String> rackAwareAssignmentTags,
-                                                  final int numStandbyReplicas) {
-        final Properties streamsConfiguration = new Properties();
-        streamsConfiguration.putAll(baseConfiguration);
-        streamsConfiguration.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, numStandbyReplicas);
-        streamsConfiguration.put(StreamsConfig.RACK_AWARE_ASSIGNMENT_TAGS_CONFIG, String.join(",", rackAwareAssignmentTags));
-        clientTags.forEach((key, value) -> streamsConfiguration.put(StreamsConfig.clientTagPrefix(key), value));
-        streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory(String.join("-", clientTags.values())).getPath());
-        return streamsConfiguration;
-    }
-
-    private static final class KafkaStreamsWithConfiguration {
-        private final Properties configuration;
-        private final KafkaStreams kafkaStreams;
-
-        KafkaStreamsWithConfiguration(final Properties configuration, final KafkaStreams kafkaStreams) {
-            this.configuration = configuration;
-            this.kafkaStreams = kafkaStreams;
-        }
-    }
-
-    private static final class TaskClientTagDistribution {
-        private final TaskClientTags activeTaskClientTags;
-        private final List<TaskClientTags> standbyTasksClientTags;
-
-        TaskClientTagDistribution(final TaskClientTags activeTaskClientTags, final List<TaskClientTags> standbyTasksClientTags) {
-            this.activeTaskClientTags = activeTaskClientTags;
-            this.standbyTasksClientTags = standbyTasksClientTags;
-        }
-
-        @Override
-        public String toString() {
-            return "TaskDistribution{" +
-                   "activeTaskClientTagsView=" + activeTaskClientTags +
-                   ", standbyTasks=" + standbyTasksClientTags +
-                   '}';
-        }
-    }
-
-    private static final class TaskClientTags {
-        private final TaskId taskId;
-        private final Map<String, String> clientTags;
-
-        TaskClientTags(final TaskId taskId, final Map<String, String> clientTags) {
-            this.taskId = taskId;
-            this.clientTags = clientTags;
-        }
-
-        @Override
-        public String toString() {
-            return "TaskClientTags{" +
-                   "taskId=" + taskId +
-                   ", clientTags=" + clientTags +
-                   '}';
-        }
-    }
-}
\ No newline at end of file
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RackAwarenessStreamsPartitionAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RackAwarenessStreamsPartitionAssignorTest.java
new file mode 100644
index 0000000000..12a14c2dc5
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RackAwarenessStreamsPartitionAssignorTest.java
@@ -0,0 +1,576 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.clients.admin.Admin;
+import org.apache.kafka.clients.admin.AdminClient;
+import org.apache.kafka.clients.admin.ListOffsetsResult;
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor;
+import org.apache.kafka.common.Cluster;
+import org.apache.kafka.common.Node;
+import org.apache.kafka.common.PartitionInfo;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.internals.KafkaFutureImpl;
+import org.apache.kafka.common.utils.MockTime;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.internals.assignment.AssignmentInfo;
+import org.apache.kafka.streams.processor.internals.assignment.ReferenceContainer;
+import org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo;
+import org.apache.kafka.test.MockApiProcessorSupplier;
+import org.apache.kafka.test.MockClientSupplier;
+import org.apache.kafka.test.MockInternalTopicManager;
+import org.apache.kafka.test.MockKeyValueStoreBuilder;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.junit.MockitoJUnitRunner;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+import java.util.stream.Collectors;
+
+import static java.util.Arrays.asList;
+import static java.util.Collections.emptySet;
+import static java.util.Collections.singletonList;
+import static org.apache.kafka.common.utils.Utils.mkEntry;
+import static org.apache.kafka.common.utils.Utils.mkMap;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.EMPTY_CHANGELOG_END_OFFSETS;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.EMPTY_TASKS;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_1;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_2;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_3;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_4;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_5;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_6;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_7;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_8;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_9;
+import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.LATEST_SUPPORTED_VERSION;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+@RunWith(MockitoJUnitRunner.StrictStubs.class)
+public class RackAwarenessStreamsPartitionAssignorTest {
+
+    private final List<PartitionInfo> infos = asList(
+        new PartitionInfo("topic0", 0, Node.noNode(), new Node[0], new 
Node[0]),
+        new PartitionInfo("topic0", 1, Node.noNode(), new Node[0], new 
Node[0]),
+        new PartitionInfo("topic0", 2, Node.noNode(), new Node[0], new 
Node[0]),
+        new PartitionInfo("topic1", 0, Node.noNode(), new Node[0], new 
Node[0]),
+        new PartitionInfo("topic1", 1, Node.noNode(), new Node[0], new 
Node[0]),
+        new PartitionInfo("topic1", 2, Node.noNode(), new Node[0], new 
Node[0]),
+        new PartitionInfo("topic2", 0, Node.noNode(), new Node[0], new 
Node[0]),
+        new PartitionInfo("topic2", 1, Node.noNode(), new Node[0], new 
Node[0]),
+        new PartitionInfo("topic2", 2, Node.noNode(), new Node[0], new 
Node[0]),
+        new PartitionInfo("topic3", 0, Node.noNode(), new Node[0], new 
Node[0]),
+        new PartitionInfo("topic3", 1, Node.noNode(), new Node[0], new 
Node[0]),
+        new PartitionInfo("topic3", 2, Node.noNode(), new Node[0], new 
Node[0]),
+        new PartitionInfo("topic4", 0, Node.noNode(), new Node[0], new 
Node[0]),
+        new PartitionInfo("topic4", 1, Node.noNode(), new Node[0], new 
Node[0]),
+        new PartitionInfo("topic4", 2, Node.noNode(), new Node[0], new Node[0])
+    );
+
+    final String consumer1 = "consumer1";
+    final String consumer2 = "consumer2";
+    final String consumer3 = "consumer3";
+    final String consumer4 = "consumer4";
+    final String consumer5 = "consumer5";
+    final String consumer6 = "consumer6";
+    final String consumer7 = "consumer7";
+    final String consumer8 = "consumer8";
+    final String consumer9 = "consumer9";
+
+
+    private final Cluster metadata = new Cluster(
+            "cluster",
+            singletonList(Node.noNode()),
+            infos,
+            emptySet(),
+            emptySet());
+
+    private final static List<String> ALL_TAG_KEYS = new ArrayList<>();
+    static {
+        for (int i = 0; i < StreamsConfig.MAX_RACK_AWARE_ASSIGNMENT_TAG_LIST_SIZE; i++) {
+            ALL_TAG_KEYS.add("key-" + i);
+        }
+    }
+
+    private final StreamsPartitionAssignor partitionAssignor = new StreamsPartitionAssignor();
+    private final MockClientSupplier mockClientSupplier = new MockClientSupplier();
+    private static final String USER_END_POINT = "localhost:8080";
+    private static final String APPLICATION_ID = "stream-partition-assignor-test";
+
+    private TaskManager taskManager;
+    private Admin adminClient;
+    private StreamsConfig streamsConfig = new StreamsConfig(configProps());
+    private final InternalTopologyBuilder builder = new InternalTopologyBuilder();
+    private TopologyMetadata topologyMetadata = new TopologyMetadata(builder, streamsConfig);
+    private final StreamsMetadataState streamsMetadataState = mock(StreamsMetadataState.class);
+    private final Map<String, ConsumerPartitionAssignor.Subscription> subscriptions = new HashMap<>();
+    private final MockTime time = new MockTime();
+
+    @SuppressWarnings("unchecked")
+    private Map<String, Object> configProps() {
+        final Map<String, Object> configurationMap = new HashMap<>();
+        configurationMap.put(StreamsConfig.APPLICATION_ID_CONFIG, APPLICATION_ID);
+        configurationMap.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, USER_END_POINT);
+        final ReferenceContainer referenceContainer = new ReferenceContainer();
+        referenceContainer.mainConsumer = (Consumer<byte[], byte[]>) mock(Consumer.class);
+        referenceContainer.adminClient = adminClient;
+        referenceContainer.taskManager = taskManager;
+        referenceContainer.streamsMetadataState = streamsMetadataState;
+        referenceContainer.time = time;
+        configurationMap.put(StreamsConfig.InternalConfig.REFERENCE_CONTAINER_PARTITION_ASSIGNOR, referenceContainer);
+        configurationMap.put(StreamsConfig.RACK_AWARE_ASSIGNMENT_TAGS_CONFIG, String.join(",", ALL_TAG_KEYS));
+        ALL_TAG_KEYS.forEach(key -> configurationMap.put(StreamsConfig.clientTagPrefix(key), "dummy"));
+        return configurationMap;
+    }
+
+    // Make sure to complete setting up any mocks (such as TaskManager or AdminClient) before configuring the assignor
+    private void configurePartitionAssignorWith(final Map<String, Object> props) {
+        final Map<String, Object> configMap = configProps();
+        configMap.putAll(props);
+
+        streamsConfig = new StreamsConfig(configMap);
+        topologyMetadata = new TopologyMetadata(builder, streamsConfig);
+        partitionAssignor.configure(configMap);
+
+        overwriteInternalTopicManagerWithMock();
+    }
+
+    // Useful for tests that don't care about the task offset sums
+    private void createMockTaskManager() {
+        taskManager = mock(TaskManager.class);
+        when(taskManager.topologyMetadata()).thenReturn(topologyMetadata);
+        when(taskManager.processId()).thenReturn(UUID_1);
+        topologyMetadata.buildAndRewriteTopology();
+    }
+
+    // If you don't care about setting the end offsets for each specific topic partition, the helper method
+    // getTopicPartitionOffsetsMap is useful for building this input map for all partitions
+    private void createMockAdminClient(final Map<TopicPartition, Long> changelogEndOffsets) {
+        adminClient = mock(AdminClient.class);
+
+        final ListOffsetsResult result = mock(ListOffsetsResult.class);
+        final KafkaFutureImpl<Map<TopicPartition, ListOffsetsResult.ListOffsetsResultInfo>> allFuture = new KafkaFutureImpl<>();
+        allFuture.complete(changelogEndOffsets.entrySet().stream().collect(Collectors.toMap(
+                Map.Entry::getKey,
+                t -> {
+                    final ListOffsetsResult.ListOffsetsResultInfo info = mock(ListOffsetsResult.ListOffsetsResultInfo.class);
+                    when(info.offset()).thenReturn(t.getValue());
+                    return info;
+                }))
+        );
+
+        when(adminClient.listOffsets(any())).thenReturn(result);
+        when(result.all()).thenReturn(allFuture);
+    }
+
+    private void overwriteInternalTopicManagerWithMock() {
+        final MockInternalTopicManager mockInternalTopicManager = new MockInternalTopicManager(
+                time,
+                streamsConfig,
+                mockClientSupplier.restoreConsumer,
+                false
+        );
+        partitionAssignor.setInternalTopicManager(mockInternalTopicManager);
+    }
+
+    @Before
+    public void setUp() {
+        createMockAdminClient(EMPTY_CHANGELOG_END_OFFSETS);
+    }
+
+    @Test
+    public void shouldDistributeWithMaximumNumberOfClientTags() {
+        setupTopology(3, 2);
+
+        createMockTaskManager();
+        createMockAdminClient(getTopicPartitionOffsetsMap(
+            Arrays.asList(APPLICATION_ID + "-store2-changelog", APPLICATION_ID 
+ "-store3-changelog", APPLICATION_ID + "-store4-changelog"),
+            Arrays.asList(3, 3, 3)));
+        
configurePartitionAssignorWith(Collections.singletonMap(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG,
 1));
+
+        final Map<String, String> clientTags1 = new HashMap<>();
+        final Map<String, String> clientTags2 = new HashMap<>();
+
+        for (int i = 0; i < ALL_TAG_KEYS.size(); i++) {
+            final String key = ALL_TAG_KEYS.get(i);
+            clientTags1.put(key, "value-1-" + i);
+            clientTags2.put(key, "value-2-" + i);
+        }
+
+        final Map<String, Map<String, String>> hostTags = new HashMap<>();
+        subscriptions.put(consumer1, getSubscription(UUID_1, EMPTY_TASKS, clientTags1));
+        hostTags.put(consumer1, clientTags1);
+        subscriptions.put(consumer2, getSubscription(UUID_2, EMPTY_TASKS, clientTags1));
+        hostTags.put(consumer2, clientTags1);
+        subscriptions.put(consumer3, getSubscription(UUID_3, EMPTY_TASKS, clientTags2));
+        hostTags.put(consumer3, clientTags2);
+
+        Map<String, ConsumerPartitionAssignor.Assignment> assignments = partitionAssignor
+            .assign(metadata, new ConsumerPartitionAssignor.GroupSubscription(subscriptions))
+            .groupAssignment();
+
+        verifyIdealTaskDistributionReached(getClientTagDistributions(assignments, hostTags), ALL_TAG_KEYS);
+
+        // kill the first consumer and rebalance, should still achieve ideal distribution
+        subscriptions.clear();
+        subscriptions.put(consumer2, getSubscription(UUID_2, AssignmentInfo.decode(assignments.get(consumer2).userData()).activeTasks(), clientTags1));
+        subscriptions.put(consumer3, getSubscription(UUID_3, AssignmentInfo.decode(assignments.get(consumer3).userData()).activeTasks(), clientTags2));
+
+        assignments = partitionAssignor.assign(metadata, new ConsumerPartitionAssignor.GroupSubscription(subscriptions))
+            .groupAssignment();
+
+        verifyIdealTaskDistributionReached(getClientTagDistributions(assignments, hostTags), ALL_TAG_KEYS);
+    }
+
+    @Test
+    public void shouldDistributeOnDistinguishingTagSubset() {
+        setupTopology(3, 0);
+
+        createMockTaskManager();
+        createMockAdminClient(getTopicPartitionOffsetsMap(
+            Arrays.asList(APPLICATION_ID + "-store0-changelog", APPLICATION_ID 
+ "-store1-changelog", APPLICATION_ID + "-store2-changelog"),
+            Arrays.asList(3, 3, 3)));
+        
configurePartitionAssignorWith(Collections.singletonMap(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG,
 1));
+
+        // use the same tag value for key1, and different value for key2
+        // then we verify that for key2 we still achieve ideal distribution
+        final Map<String, String> clientTags1 = new HashMap<>();
+        final Map<String, String> clientTags2 = new HashMap<>();
+        clientTags1.put(ALL_TAG_KEYS.get(0), "value-1-all");
+        clientTags2.put(ALL_TAG_KEYS.get(0), "value-2-all");
+        clientTags1.put(ALL_TAG_KEYS.get(1), "value-1-1");
+        clientTags2.put(ALL_TAG_KEYS.get(1), "value-2-2");
+
+        final String consumer1 = "consumer1";
+        final String consumer2 = "consumer2";
+        final String consumer3 = "consumer3";
+        final String consumer4 = "consumer4";
+        final String consumer5 = "consumer5";
+        final String consumer6 = "consumer6";
+
+        final Map<String, Map<String, String>> hostTags = new HashMap<>();
+        subscriptions.put(consumer1, getSubscription(UUID_1, EMPTY_TASKS, clientTags1));
+        hostTags.put(consumer1, clientTags1);
+        subscriptions.put(consumer2, getSubscription(UUID_2, EMPTY_TASKS, clientTags1));
+        hostTags.put(consumer2, clientTags1);
+        subscriptions.put(consumer3, getSubscription(UUID_3, EMPTY_TASKS, clientTags1));
+        hostTags.put(consumer3, clientTags1);
+        subscriptions.put(consumer4, getSubscription(UUID_4, EMPTY_TASKS, clientTags2));
+        hostTags.put(consumer4, clientTags2);
+        subscriptions.put(consumer5, getSubscription(UUID_5, EMPTY_TASKS, clientTags2));
+        hostTags.put(consumer5, clientTags2);
+        subscriptions.put(consumer6, getSubscription(UUID_6, EMPTY_TASKS, clientTags2));
+        hostTags.put(consumer6, clientTags2);
+
+        final Map<String, ConsumerPartitionAssignor.Assignment> assignments = partitionAssignor
+            .assign(metadata, new ConsumerPartitionAssignor.GroupSubscription(subscriptions))
+            .groupAssignment();
+
+        verifyIdealTaskDistributionReached(getClientTagDistributions(assignments, hostTags), Collections.singletonList(ALL_TAG_KEYS.get(1)));
+    }
+
+    @Test
+    public void shouldDistributeWithMultipleStandbys() {
+        setupTopology(3, 0);
+
+        createMockTaskManager();
+        createMockAdminClient(getTopicPartitionOffsetsMap(
+            Arrays.asList(APPLICATION_ID + "-store0-changelog", APPLICATION_ID 
+ "-store1-changelog", APPLICATION_ID + "-store2-changelog"),
+            Arrays.asList(3, 3, 3)));
+        
configurePartitionAssignorWith(Collections.singletonMap(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG,
 2));
+
+        final Map<String, String> clientTags1 = mkMap(
+            mkEntry(ALL_TAG_KEYS.get(0), "value-0-1"),
+            mkEntry(ALL_TAG_KEYS.get(1), "value-1-1"));
+        final Map<String, String> clientTags2 = mkMap(
+            mkEntry(ALL_TAG_KEYS.get(0), "value-0-1"),
+            mkEntry(ALL_TAG_KEYS.get(1), "value-1-2"));
+        final Map<String, String> clientTags3 = mkMap(
+            mkEntry(ALL_TAG_KEYS.get(0), "value-0-1"),
+            mkEntry(ALL_TAG_KEYS.get(1), "value-1-3"));
+        final Map<String, String> clientTags4 = mkMap(
+            mkEntry(ALL_TAG_KEYS.get(0), "value-0-2"),
+            mkEntry(ALL_TAG_KEYS.get(1), "value-1-1"));
+        final Map<String, String> clientTags5 = mkMap(
+            mkEntry(ALL_TAG_KEYS.get(0), "value-0-2"),
+            mkEntry(ALL_TAG_KEYS.get(1), "value-1-2"));
+        final Map<String, String> clientTags6 = mkMap(
+            mkEntry(ALL_TAG_KEYS.get(0), "value-0-2"),
+            mkEntry(ALL_TAG_KEYS.get(1), "value-1-3"));
+        final Map<String, String> clientTags7 = mkMap(
+            mkEntry(ALL_TAG_KEYS.get(0), "value-0-3"),
+            mkEntry(ALL_TAG_KEYS.get(1), "value-1-1"));
+        final Map<String, String> clientTags8 = mkMap(
+            mkEntry(ALL_TAG_KEYS.get(0), "value-0-3"),
+            mkEntry(ALL_TAG_KEYS.get(1), "value-1-2"));
+        final Map<String, String> clientTags9 = mkMap(
+            mkEntry(ALL_TAG_KEYS.get(0), "value-0-3"),
+            mkEntry(ALL_TAG_KEYS.get(1), "value-1-3"));
+
+        final Map<String, Map<String, String>> hostTags = new HashMap<>();
+        subscriptions.put(consumer1, getSubscription(UUID_1, EMPTY_TASKS, clientTags1));
+        hostTags.put(consumer1, clientTags1);
+        subscriptions.put(consumer2, getSubscription(UUID_2, EMPTY_TASKS, clientTags2));
+        hostTags.put(consumer2, clientTags2);
+        subscriptions.put(consumer3, getSubscription(UUID_3, EMPTY_TASKS, clientTags3));
+        hostTags.put(consumer3, clientTags3);
+        subscriptions.put(consumer4, getSubscription(UUID_4, EMPTY_TASKS, clientTags4));
+        hostTags.put(consumer4, clientTags4);
+        subscriptions.put(consumer5, getSubscription(UUID_5, EMPTY_TASKS, clientTags5));
+        hostTags.put(consumer5, clientTags5);
+        subscriptions.put(consumer6, getSubscription(UUID_6, EMPTY_TASKS, clientTags6));
+        hostTags.put(consumer6, clientTags6);
+        subscriptions.put(consumer7, getSubscription(UUID_7, EMPTY_TASKS, clientTags7));
+        hostTags.put(consumer7, clientTags7);
+        subscriptions.put(consumer8, getSubscription(UUID_8, EMPTY_TASKS, clientTags8));
+        hostTags.put(consumer8, clientTags8);
+        subscriptions.put(consumer9, getSubscription(UUID_9, EMPTY_TASKS, clientTags9));
+        hostTags.put(consumer9, clientTags9);
+
+        final Map<String, ConsumerPartitionAssignor.Assignment> assignments = partitionAssignor
+            .assign(metadata, new ConsumerPartitionAssignor.GroupSubscription(subscriptions))
+            .groupAssignment();
+
+        verifyIdealTaskDistributionReached(getClientTagDistributions(assignments, hostTags), Arrays.asList(ALL_TAG_KEYS.get(0), ALL_TAG_KEYS.get(1)));
+    }
+
+    @Test
+    public void shouldDistributePartiallyWhenDoNotHaveEnoughClients() {
+        setupTopology(3, 0);
+
+        createMockTaskManager();
+        createMockAdminClient(getTopicPartitionOffsetsMap(
+            Arrays.asList(APPLICATION_ID + "-store0-changelog", APPLICATION_ID 
+ "-store1-changelog", APPLICATION_ID + "-store2-changelog"),
+            Arrays.asList(3, 3, 3)));
+        
configurePartitionAssignorWith(Collections.singletonMap(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG,
 2));
+
+        final Map<String, String> clientTags1 = mkMap(
+            mkEntry(ALL_TAG_KEYS.get(0), "value-0-1"),
+            mkEntry(ALL_TAG_KEYS.get(1), "value-1-1"));
+        final Map<String, String> clientTags2 = mkMap(
+            mkEntry(ALL_TAG_KEYS.get(0), "value-0-1"),
+            mkEntry(ALL_TAG_KEYS.get(1), "value-1-2"));
+        final Map<String, String> clientTags3 = mkMap(
+            mkEntry(ALL_TAG_KEYS.get(0), "value-0-1"),
+            mkEntry(ALL_TAG_KEYS.get(1), "value-1-3"));
+        final Map<String, String> clientTags4 = mkMap(
+            mkEntry(ALL_TAG_KEYS.get(0), "value-0-2"),
+            mkEntry(ALL_TAG_KEYS.get(1), "value-1-1"));
+        final Map<String, String> clientTags5 = mkMap(
+            mkEntry(ALL_TAG_KEYS.get(0), "value-0-2"),
+            mkEntry(ALL_TAG_KEYS.get(1), "value-1-2"));
+        final Map<String, String> clientTags6 = mkMap(
+            mkEntry(ALL_TAG_KEYS.get(0), "value-0-2"),
+            mkEntry(ALL_TAG_KEYS.get(1), "value-1-3"));
+
+        final Map<String, Map<String, String>> hostTags = new HashMap<>();
+        subscriptions.put(consumer1, getSubscription(UUID_1, EMPTY_TASKS, clientTags1));
+        hostTags.put(consumer1, clientTags1);
+        subscriptions.put(consumer2, getSubscription(UUID_2, EMPTY_TASKS, clientTags2));
+        hostTags.put(consumer2, clientTags2);
+        subscriptions.put(consumer3, getSubscription(UUID_3, EMPTY_TASKS, clientTags3));
+        hostTags.put(consumer3, clientTags3);
+        subscriptions.put(consumer4, getSubscription(UUID_4, EMPTY_TASKS, clientTags4));
+        hostTags.put(consumer4, clientTags4);
+        subscriptions.put(consumer5, getSubscription(UUID_5, EMPTY_TASKS, clientTags5));
+        hostTags.put(consumer5, clientTags5);
+        subscriptions.put(consumer6, getSubscription(UUID_6, EMPTY_TASKS, clientTags6));
+        hostTags.put(consumer6, clientTags6);
+
+        final Map<String, ConsumerPartitionAssignor.Assignment> assignments = partitionAssignor
+            .assign(metadata, new ConsumerPartitionAssignor.GroupSubscription(subscriptions))
+            .groupAssignment();
+
+        verifyIdealTaskDistributionReached(getClientTagDistributions(assignments, hostTags), Collections.singletonList(ALL_TAG_KEYS.get(1)));
+        verifyPartialTaskDistributionReached(getClientTagDistributions(assignments, hostTags), Collections.singletonList(ALL_TAG_KEYS.get(0)));
+    }
+
+    private Map<TaskId, ClientTagDistribution> getClientTagDistributions(final Map<String, ConsumerPartitionAssignor.Assignment> assignments,
+                                                                         final Map<String, Map<String, String>> hostTags) {
+        final Map<TaskId, ClientTagDistribution> taskClientTags = new HashMap<>();
+
+        for (final Map.Entry<String, ConsumerPartitionAssignor.Assignment> entry : assignments.entrySet()) {
+            final AssignmentInfo info = AssignmentInfo.decode(entry.getValue().userData());
+
+            for (final TaskId activeTaskId : info.activeTasks()) {
+                taskClientTags.putIfAbsent(activeTaskId, new ClientTagDistribution(activeTaskId));
+                final ClientTagDistribution tagDistribution = taskClientTags.get(activeTaskId);
+                tagDistribution.addActiveTags(hostTags.get(entry.getKey()));
+            }
+
+            for (final TaskId standbyTaskId : info.standbyTasks().keySet()) {
+                taskClientTags.putIfAbsent(standbyTaskId, new ClientTagDistribution(standbyTaskId));
+                final ClientTagDistribution tagDistribution = taskClientTags.get(standbyTaskId);
+                tagDistribution.addStandbyTags(hostTags.get(entry.getKey()));
+            }
+        }
+
+        return taskClientTags;
+    }
+
+    private void verifyIdealTaskDistributionReached(final Map<TaskId, ClientTagDistribution> taskClientTags,
+                                                    final List<String> tagsToCheck) {
+        for (final Map.Entry<TaskId, ClientTagDistribution> entry: taskClientTags.entrySet()) {
+            if (!tagsAmongStandbysAreDifferent(entry.getValue(), tagsToCheck))
+                throw new AssertionError("task " + entry.getKey() + "'s tag-distribution for " + tagsToCheck +
+                    " among standbys is not ideal: " + entry.getValue());
+
+            if (!tagsAmongActiveAndAllStandbysAreDifferent(entry.getValue(), tagsToCheck))
+                throw new AssertionError("task " + entry.getKey() + "'s tag-distribution for " + tagsToCheck +
+                    " between active and standbys is not ideal: " + entry.getValue());
+        }
+    }
+
+    private void verifyPartialTaskDistributionReached(final Map<TaskId, ClientTagDistribution> taskClientTags,
+                                                      final List<String> tagsToCheck) {
+        for (final Map.Entry<TaskId, ClientTagDistribution> entry: taskClientTags.entrySet()) {
+            if (!tagsAmongActiveAndAtLeastOneStandbyIsDifferent(entry.getValue(), tagsToCheck))
+                throw new AssertionError("task " + entry.getKey() + "'s tag-distribution for " + tagsToCheck +
+                    " between active and standbys is not partially ideal: " + entry.getValue());
+        }
+    }
+
+    private static boolean tagsAmongActiveAndAllStandbysAreDifferent(final ClientTagDistribution tagDistribution,
+                                                                     final List<String> tagsToCheck) {
+        return tagDistribution.standbysClientTags.stream().allMatch(standbyTags ->
+            tagsToCheck.stream().noneMatch(tag -> tagDistribution.activeClientTags.get(tag).equals(standbyTags.get(tag))));
+    }
+
+    private static boolean tagsAmongActiveAndAtLeastOneStandbyIsDifferent(final ClientTagDistribution tagDistribution,
+                                                                          final List<String> tagsToCheck) {
+        return tagDistribution.standbysClientTags.stream().anyMatch(standbyTags ->
+            tagsToCheck.stream().noneMatch(tag -> tagDistribution.activeClientTags.get(tag).equals(standbyTags.get(tag))));
+    }
+
+    private static boolean tagsAmongStandbysAreDifferent(final ClientTagDistribution tagDistribution,
+                                                         final List<String> tagsToCheck) {
+        final Map<String, Integer> statistics = new HashMap<>();
+
+        for (final Map<String, String> tags : tagDistribution.standbysClientTags) {
+            for (final Map.Entry<String, String> tag : tags.entrySet()) {
+                if (tagsToCheck.contains(tag.getKey())) {
+                    final String tagValue = tag.getValue();
+                    final Integer tagValueOccurrence = statistics.getOrDefault(tagValue, 0);
+                    statistics.put(tagValue, tagValueOccurrence + 1);
+                }
+            }
+        }
+
+        return statistics.values().stream().noneMatch(occurrence -> occurrence > 1);
+    }
+
+    private void setupTopology(final int numOfStatefulTopologies, final int numOfStatelessTopologies) {
+        if (numOfStatefulTopologies + numOfStatelessTopologies > 5) {
+            throw new IllegalArgumentException("Should not have more than 5 topologies, but have " + (numOfStatefulTopologies + numOfStatelessTopologies));
+        }
+
+        for (int i = 0; i < numOfStatelessTopologies; i++) {
+            builder.addSource(null, "source" + i, null, null, null, "topic" + i);
+            builder.addProcessor("processor" + i, new MockApiProcessorSupplier<>(), "source" + i);
+        }
+
+        for (int i = numOfStatelessTopologies; i < numOfStatelessTopologies + numOfStatefulTopologies; i++) {
+            builder.addSource(null, "source" + i, null, null, null, "topic" + i);
+            builder.addProcessor("processor" + i, new MockApiProcessorSupplier<>(), "source" + i);
+            builder.addStateStore(new MockKeyValueStoreBuilder("store" + i, false), "processor" + i);
+        }
+    }
+
+    private static final class ClientTagDistribution {
+        private final TaskId taskId;
+        private final Map<String, String> activeClientTags;
+        private final List<Map<String, String>> standbysClientTags;
+
+        ClientTagDistribution(final TaskId taskId) {
+            this.taskId = taskId;
+            this.activeClientTags = new HashMap<>();
+            this.standbysClientTags = new ArrayList<>();
+        }
+
+        void addActiveTags(final Map<String, String> activeClientTags) {
+            if (!this.activeClientTags.isEmpty()) {
+                throw new IllegalStateException("Found multiple active tasks for " + taskId + ", this should not happen");
+            }
+            this.activeClientTags.putAll(activeClientTags);
+        }
+
+        void addStandbyTags(final Map<String, String> standbyClientTags) {
+            this.standbysClientTags.add(standbyClientTags);
+        }
+
+        @Override
+        public String toString() {
+            return "ClientTagDistribution{" +
+                "taskId=" + taskId +
+                ", activeClientTags=" + activeClientTags +
+                ", standbysClientTags=" + standbysClientTags +
+                '}';
+        }
+    }
+
+    /**
+     * Helper for building the input to createMockAdminClient in cases where we don't care about the actual offsets
+     * @param changelogTopics The names of all changelog topics in the topology
+     * @param topicsNumPartitions The number of partitions for the corresponding changelog topic, such that the number
+     *            of partitions of the ith topic in changelogTopics is given by the ith element of topicsNumPartitions
+     */
+    private static Map<TopicPartition, Long> getTopicPartitionOffsetsMap(final List<String> changelogTopics,
+                                                                         final List<Integer> topicsNumPartitions) {
+        if (changelogTopics.size() != topicsNumPartitions.size()) {
+            throw new IllegalStateException("Passed in " + changelogTopics.size() + " changelog topic names, but " +
+                    topicsNumPartitions.size() + " different numPartitions for the topics");
+        }
+        final Map<TopicPartition, Long> changelogEndOffsets = new HashMap<>();
+        for (int i = 0; i < changelogTopics.size(); ++i) {
+            final String topic = changelogTopics.get(i);
+            final int numPartitions = topicsNumPartitions.get(i);
+            for (int partition = 0; partition < numPartitions; ++partition) {
+                changelogEndOffsets.put(new TopicPartition(topic, partition), Long.MAX_VALUE);
+            }
+        }
+        return changelogEndOffsets;
+    }
+
+    private static ConsumerPartitionAssignor.Subscription getSubscription(final UUID processId,
+                                                                          final Collection<TaskId> prevActiveTasks,
+                                                                          final Map<String, String> clientTags) {
+        return new ConsumerPartitionAssignor.Subscription(
+            singletonList("source1"),
+            new SubscriptionInfo(LATEST_SUPPORTED_VERSION, LATEST_SUPPORTED_VERSION, processId, null,
+                getTaskOffsetSums(prevActiveTasks), (byte) 0, 0, clientTags).encode()
+        );
+    }
+
+    // Stub offset sums for when we only care about the prev/standby task sets, not the actual offsets
+    private static Map<TaskId, Long> getTaskOffsetSums(final Collection<TaskId> activeTasks) {
+        final Map<TaskId, Long> taskOffsetSums = activeTasks.stream().collect(Collectors.toMap(t -> t, t -> Task.LATEST_OFFSET));
+        taskOffsetSums.putAll(EMPTY_TASKS.stream().collect(Collectors.toMap(t -> t, t -> 0L)));
+        return taskOffsetSums;
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
index 26efc1126b..68d2def110 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
@@ -2339,7 +2339,7 @@ public class StreamTaskTest {
         // The processor topology is missing the topics
         final ProcessorTopology topology = withSources(emptyList(), mkMap());
 
-        final TopologyException  exception = assertThrows(
+        final TopologyException exception = assertThrows(
             TopologyException.class,
             () -> new StreamTask(
                 taskId,
@@ -2358,7 +2358,7 @@ public class StreamTaskTest {
         );
 
         assertThat(exception.getMessage(), equalTo("Invalid topology: " +
-                "Topic is unknown to the topology. This may happen if 
different KafkaStreams instances of the same " +
+                "Topic " + topic1 + " is unknown to the topology. This may 
happen if different KafkaStreams instances of the same " +
                 "application execute different Topologies. Note that 
Topologies are only identical if all operators " +
                 "are added in the same order."));
     }
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentTestUtils.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentTestUtils.java
index 42a32c04b9..78c6477f38 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentTestUtils.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentTestUtils.java
@@ -65,6 +65,9 @@ public final class AssignmentTestUtils {
     public static final UUID UUID_4 = uuidForInt(4);
     public static final UUID UUID_5 = uuidForInt(5);
     public static final UUID UUID_6 = uuidForInt(6);
+    public static final UUID UUID_7 = uuidForInt(7);
+    public static final UUID UUID_8 = uuidForInt(8);
+    public static final UUID UUID_9 = uuidForInt(9);
 
     public static final TopicPartition TP_0_0 = new TopicPartition("topic0", 0);
     public static final TopicPartition TP_0_1 = new TopicPartition("topic0", 1);
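
For reference, the configuration surface exercised by these tests is the same one an application would use to enable rack-aware standby placement. A minimal sketch, with the application id, bootstrap servers, and tag values as illustrative placeholders rather than values taken from the patch:

    import java.util.Properties;
    import org.apache.kafka.streams.StreamsConfig;

    public class RackAwareConfigSketch {
        public static Properties rackAwareProps() {
            final Properties props = new Properties();
            props.put(StreamsConfig.APPLICATION_ID_CONFIG, "my-app");             // illustrative
            props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");  // illustrative
            // Keep one standby per active task so there is something to distribute.
            props.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1);
            // Tag this instance; each key counts against MAX_RACK_AWARE_ASSIGNMENT_TAG_LIST_SIZE.
            props.put(StreamsConfig.clientTagPrefix("zone"), "eu-central-1a");
            props.put(StreamsConfig.clientTagPrefix("cluster"), "k8s-cluster-1");
            // Ask the assignor to spread standbys across distinct values of these tags.
            props.put(StreamsConfig.RACK_AWARE_ASSIGNMENT_TAGS_CONFIG, "zone,cluster");
            return props;
        }
    }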
