cadonna commented on a change in pull request #8497:
URL: https://github.com/apache/kafka/pull/8497#discussion_r411366111



##########
File path: 
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java
##########
@@ -89,95 +88,72 @@ public boolean assign() {
             return false;
         }
 
-        final Map<UUID, List<TaskId>> warmupTaskAssignment = 
initializeEmptyTaskAssignmentMap(sortedClients);
-        final Map<UUID, List<TaskId>> standbyTaskAssignment = 
initializeEmptyTaskAssignmentMap(sortedClients);
-        final Map<UUID, List<TaskId>> statelessActiveTaskAssignment = 
initializeEmptyTaskAssignmentMap(sortedClients);
+        final Map<TaskId, Integer> tasksToRemainingStandbys =
+            statefulTasks.stream().collect(Collectors.toMap(task -> task, t -> 
configs.numStandbyReplicas));
 
-        // ---------------- Stateful Active Tasks ---------------- //
+        final boolean followupRebalanceNeeded = 
assignStatefulActiveTasks(tasksToRemainingStandbys);

Review comment:
       I love it when a comment gets killed by a meaningful method name!

##########
File path: 
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java
##########
@@ -89,95 +88,72 @@ public boolean assign() {
             return false;
         }
 
-        final Map<UUID, List<TaskId>> warmupTaskAssignment = 
initializeEmptyTaskAssignmentMap(sortedClients);
-        final Map<UUID, List<TaskId>> standbyTaskAssignment = 
initializeEmptyTaskAssignmentMap(sortedClients);
-        final Map<UUID, List<TaskId>> statelessActiveTaskAssignment = 
initializeEmptyTaskAssignmentMap(sortedClients);
+        final Map<TaskId, Integer> tasksToRemainingStandbys =
+            statefulTasks.stream().collect(Collectors.toMap(task -> task, t -> 
configs.numStandbyReplicas));
 
-        // ---------------- Stateful Active Tasks ---------------- //
+        final boolean followupRebalanceNeeded = 
assignStatefulActiveTasks(tasksToRemainingStandbys);
 
-        final Map<UUID, List<TaskId>> statefulActiveTaskAssignment =
-            new DefaultStateConstrainedBalancedAssignor().assign(
-                statefulTasksToRankedCandidates,
-                configs.balanceFactor,
-                sortedClients,
-                clientsToNumberOfThreads,
-                tasksToCaughtUpClients
-            );
+        assignStandbyReplicaTasks(tasksToRemainingStandbys);
+
+        assignStatelessActiveTasks();
 
-        // ---------------- Warmup Replica Tasks ---------------- //
+        return followupRebalanceNeeded;
+    }
 
-        final Map<UUID, List<TaskId>> balancedStatefulActiveTaskAssignment =
+    private boolean assignStatefulActiveTasks(final Map<TaskId, Integer> 
tasksToRemainingStandbys) {
+        final Map<UUID, List<TaskId>> statefulActiveTaskAssignment =
             new DefaultBalancedAssignor().assign(
                 sortedClients,
                 statefulTasks,
                 clientsToNumberOfThreads,
                 configs.balanceFactor);

Review comment:
       prop:
   ```suggestion
           final Map<UUID, List<TaskId>> statefulActiveTaskAssignment = new 
DefaultBalancedAssignor().assign(
               sortedClients,
               statefulTasks,
               clientsToNumberOfThreads,
               configs.balanceFactor
           );
   ```

##########
File path: 
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ValidClientsByTaskLoadQueue.java
##########
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import java.util.Set;
+import java.util.UUID;
+import java.util.function.BiFunction;
+import org.apache.kafka.streams.processor.TaskId;
+
+/**
+ * Wraps a priority queue of clients and returns the next valid candidate(s) 
based on the current task assignment
+ */
+class ValidClientsByTaskLoadQueue {
+    private final PriorityQueue<UUID> clientsByTaskLoad;
+    private final BiFunction<UUID, TaskId, Boolean> validClientCriteria;
+    private final Set<UUID> uniqueClients = new HashSet<>();
+
+    ValidClientsByTaskLoadQueue(final Map<UUID, ClientState> clientStates,
+                                final BiFunction<UUID, TaskId, Boolean> 
validClientCriteria) {
+        clientsByTaskLoad = getClientPriorityQueueByTaskLoad(clientStates);
+        this.validClientCriteria = validClientCriteria;
+    }
+
+    /**
+=     * @return the next least loaded client that satisfies the given 
criteria, or null if none do
+     */
+    UUID poll(final TaskId task) {
+        final List<UUID> validClient = poll(task, 1);
+        return validClient.isEmpty() ? null : validClient.get(0);
+    }
+
+    /**
+     * @return the next N <= {@code numClientsPerTask} clients in the 
underlying priority queue that are valid
+     * candidates for the given task
+     */
+    List<UUID> poll(final TaskId task, final int numClients) {
+        final List<UUID> nextLeastLoadedValidClients = new LinkedList<>();
+        final Set<UUID> invalidPolledClients = new HashSet<>();
+        while (nextLeastLoadedValidClients.size() < numClients) {
+            UUID candidateClient;
+            while (true) {
+                candidateClient = clientsByTaskLoad.poll();
+                if (candidateClient == null) {
+                    offerAll(invalidPolledClients);
+                    return nextLeastLoadedValidClients;
+                }
+
+                if (validClientCriteria.apply(candidateClient, task)) {
+                    nextLeastLoadedValidClients.add(candidateClient);
+                    break;
+                } else {
+                    invalidPolledClients.add(candidateClient);
+                }
+            }
+        }
+        offerAll(invalidPolledClients);
+        return nextLeastLoadedValidClients;
+    }
+
+    void offerAll(final Collection<UUID> clients) {
+        for (final UUID client : clients) {
+            offer(client);
+        }
+    }
+
+    void offer(final UUID client) {
+        if (uniqueClients.contains(client)) {

Review comment:
       Q: Why do we need `uniqueClients` here? Would it not suffice to check 
`clientsByTaskLoad.contains(client)` instead?

##########
File path: 
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovement.java
##########
@@ -16,128 +16,94 @@
  */
 package org.apache.kafka.streams.processor.internals.assignment;
 
-import java.util.HashMap;
-import java.util.Iterator;
-import java.util.LinkedList;
+import static 
org.apache.kafka.streams.processor.internals.assignment.AssignmentUtils.taskIsCaughtUpOnClient;
+
 import java.util.List;
 import java.util.Map;
-import java.util.Objects;
-import java.util.Set;
 import java.util.SortedSet;
+import java.util.TreeSet;
 import java.util.UUID;
 import org.apache.kafka.streams.processor.TaskId;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 
 public class TaskMovement {
-    private static final Logger log = 
LoggerFactory.getLogger(TaskMovement.class);
-
     final TaskId task;
-    final UUID source;
-    final UUID destination;
+    private final UUID destination;
 
-    TaskMovement(final TaskId task, final UUID source, final UUID destination) 
{
+    TaskMovement(final TaskId task, final UUID destination) {
         this.task = task;
-        this.source = source;
         this.destination = destination;
     }
 
-    @Override
-    public boolean equals(final Object o) {
-        if (this == o) {
-            return true;
-        }
-        if (o == null || getClass() != o.getClass()) {
-            return false;
-        }
-        final TaskMovement movement = (TaskMovement) o;
-        return Objects.equals(task, movement.task) &&
-                   Objects.equals(source, movement.source) &&
-                   Objects.equals(destination, movement.destination);
-    }
-
-    @Override
-    public int hashCode() {
-        return Objects.hash(task, source, destination);
-    }
-
     /**
-     * Computes the movement of tasks from the state constrained to the 
balanced assignment, up to the configured
-     * {@code max.warmup.replicas}. A movement corresponds to a warmup replica 
on the destination client, with
-     * a few exceptional cases:
-     * <p>
-     * 1. Tasks whose destination clients are caught-up, or whose source 
clients are not caught-up, will be moved
-     * immediately from the source to the destination in the state constrained 
assignment
-     * 2. Tasks whose destination client previously had this task as a standby 
will not be counted towards the total
-     * {@code max.warmup.replicas}. Instead they will be counted against that 
task's total {@code num.standby.replicas}.
-     *
-     * @param statefulActiveTaskAssignment the initial, state constrained 
assignment, with the source clients
-     * @param balancedStatefulActiveTaskAssignment the final, balanced 
assignment, with the destination clients
-     * @return list of the task movements from statefulActiveTaskAssignment to 
balancedStatefulActiveTaskAssignment
+     * @return whether any warmup replicas were assigned
      */
-    static List<TaskMovement> getMovements(final Map<UUID, List<TaskId>> 
statefulActiveTaskAssignment,
-                                           final Map<UUID, List<TaskId>> 
balancedStatefulActiveTaskAssignment,
-                                           final Map<TaskId, SortedSet<UUID>> 
tasksToCaughtUpClients,
-                                           final Map<UUID, ClientState> 
clientStates,
-                                           final Map<TaskId, Integer> 
tasksToRemainingStandbys,
-                                           final int maxWarmupReplicas) {
-        if (statefulActiveTaskAssignment.size() != 
balancedStatefulActiveTaskAssignment.size()) {
-            throw new IllegalStateException("Tried to compute movements but 
assignments differ in size.");
-        }
+    static boolean assignTaskMovements(final Map<UUID, List<TaskId>> 
statefulActiveTaskAssignment,
+                                       final Map<TaskId, SortedSet<UUID>> 
tasksToCaughtUpClients,
+                                       final Map<UUID, ClientState> 
clientStates,
+                                       final Map<TaskId, Integer> 
tasksToRemainingStandbys,
+                                       final int maxWarmupReplicas) {
+        boolean warmupReplicasAssigned = false;
+
+        final ValidClientsByTaskLoadQueue clientsByTaskLoad =
+            new ValidClientsByTaskLoadQueue(
+                clientStates,
+                (client, task) -> taskIsCaughtUpOnClient(task, client, 
tasksToCaughtUpClients)
+            );

Review comment:
       prop:
   ```suggestion
           final ValidClientsByTaskLoadQueue clientsByTaskLoad = new 
ValidClientsByTaskLoadQueue(
               clientStates,
               (client, task) -> taskIsCaughtUpOnClient(task, client, 
tasksToCaughtUpClients)
           );
   ```

##########
File path: 
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ValidClientsByTaskLoadQueue.java
##########
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import java.util.Set;
+import java.util.UUID;
+import java.util.function.BiFunction;
+import org.apache.kafka.streams.processor.TaskId;
+
+/**
+ * Wraps a priority queue of clients and returns the next valid candidate(s) 
based on the current task assignment
+ */
+class ValidClientsByTaskLoadQueue {

Review comment:
       req: Please add unit tests for this class.

##########
File path: 
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovement.java
##########
@@ -16,128 +16,94 @@
  */
 package org.apache.kafka.streams.processor.internals.assignment;
 
-import java.util.HashMap;
-import java.util.Iterator;
-import java.util.LinkedList;
+import static 
org.apache.kafka.streams.processor.internals.assignment.AssignmentUtils.taskIsCaughtUpOnClient;
+
 import java.util.List;
 import java.util.Map;
-import java.util.Objects;
-import java.util.Set;
 import java.util.SortedSet;
+import java.util.TreeSet;
 import java.util.UUID;
 import org.apache.kafka.streams.processor.TaskId;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 
 public class TaskMovement {
-    private static final Logger log = 
LoggerFactory.getLogger(TaskMovement.class);
-
     final TaskId task;
-    final UUID source;
-    final UUID destination;
+    private final UUID destination;
 
-    TaskMovement(final TaskId task, final UUID source, final UUID destination) 
{
+    TaskMovement(final TaskId task, final UUID destination) {
         this.task = task;
-        this.source = source;
         this.destination = destination;
     }
 
-    @Override
-    public boolean equals(final Object o) {
-        if (this == o) {
-            return true;
-        }
-        if (o == null || getClass() != o.getClass()) {
-            return false;
-        }
-        final TaskMovement movement = (TaskMovement) o;
-        return Objects.equals(task, movement.task) &&
-                   Objects.equals(source, movement.source) &&
-                   Objects.equals(destination, movement.destination);
-    }
-
-    @Override
-    public int hashCode() {
-        return Objects.hash(task, source, destination);
-    }
-
     /**
-     * Computes the movement of tasks from the state constrained to the 
balanced assignment, up to the configured
-     * {@code max.warmup.replicas}. A movement corresponds to a warmup replica 
on the destination client, with
-     * a few exceptional cases:
-     * <p>
-     * 1. Tasks whose destination clients are caught-up, or whose source 
clients are not caught-up, will be moved
-     * immediately from the source to the destination in the state constrained 
assignment
-     * 2. Tasks whose destination client previously had this task as a standby 
will not be counted towards the total
-     * {@code max.warmup.replicas}. Instead they will be counted against that 
task's total {@code num.standby.replicas}.
-     *
-     * @param statefulActiveTaskAssignment the initial, state constrained 
assignment, with the source clients
-     * @param balancedStatefulActiveTaskAssignment the final, balanced 
assignment, with the destination clients
-     * @return list of the task movements from statefulActiveTaskAssignment to 
balancedStatefulActiveTaskAssignment
+     * @return whether any warmup replicas were assigned
      */
-    static List<TaskMovement> getMovements(final Map<UUID, List<TaskId>> 
statefulActiveTaskAssignment,
-                                           final Map<UUID, List<TaskId>> 
balancedStatefulActiveTaskAssignment,
-                                           final Map<TaskId, SortedSet<UUID>> 
tasksToCaughtUpClients,
-                                           final Map<UUID, ClientState> 
clientStates,
-                                           final Map<TaskId, Integer> 
tasksToRemainingStandbys,
-                                           final int maxWarmupReplicas) {
-        if (statefulActiveTaskAssignment.size() != 
balancedStatefulActiveTaskAssignment.size()) {
-            throw new IllegalStateException("Tried to compute movements but 
assignments differ in size.");
-        }
+    static boolean assignTaskMovements(final Map<UUID, List<TaskId>> 
statefulActiveTaskAssignment,
+                                       final Map<TaskId, SortedSet<UUID>> 
tasksToCaughtUpClients,
+                                       final Map<UUID, ClientState> 
clientStates,
+                                       final Map<TaskId, Integer> 
tasksToRemainingStandbys,
+                                       final int maxWarmupReplicas) {
+        boolean warmupReplicasAssigned = false;
+
+        final ValidClientsByTaskLoadQueue clientsByTaskLoad =
+            new ValidClientsByTaskLoadQueue(
+                clientStates,
+                (client, task) -> taskIsCaughtUpOnClient(task, client, 
tasksToCaughtUpClients)
+            );
 
-        final Map<TaskId, UUID> taskToDestinationClient = new HashMap<>();
-        for (final Map.Entry<UUID, List<TaskId>> clientEntry : 
balancedStatefulActiveTaskAssignment.entrySet()) {
-            final UUID destination = clientEntry.getKey();
-            for (final TaskId task : clientEntry.getValue()) {
-                taskToDestinationClient.put(task, destination);
+        final SortedSet<TaskMovement> taskMovements = new TreeSet<>(
+            (movement, other) -> {
+                final int numCaughtUpClients = 
tasksToCaughtUpClients.get(movement.task).size();
+                final int otherNumCaughtUpClients = 
tasksToCaughtUpClients.get(other.task).size();
+                if (numCaughtUpClients != otherNumCaughtUpClients) {
+                    return numCaughtUpClients - otherNumCaughtUpClients;
+                } else {
+                    return movement.task.compareTo(other.task);
+                }
             }
+        );
+
+        for (final Map.Entry<UUID, List<TaskId>> assignmentEntry : 
statefulActiveTaskAssignment.entrySet()) {
+            final UUID client = assignmentEntry.getKey();
+            final ClientState state = clientStates.get(client);
+            for (final TaskId task : assignmentEntry.getValue()) {
+                if (taskIsCaughtUpOnClient(task, client, 
tasksToCaughtUpClients)) {
+                    state.assignActive(task);
+                } else {
+                    final TaskMovement taskMovement = new TaskMovement(task,  
client);
+                    taskMovements.add(taskMovement);
+                }
+            }
+            clientsByTaskLoad.offer(client);
         }
 
-        int remainingAllowedWarmupReplicas = maxWarmupReplicas;
-        final List<TaskMovement> movements = new LinkedList<>();
-        for (final Map.Entry<UUID, List<TaskId>> sourceClientEntry : 
statefulActiveTaskAssignment.entrySet()) {
-            final UUID source = sourceClientEntry.getKey();
+        int remainingWarmupReplicas = maxWarmupReplicas;
+        for (final TaskMovement movement : taskMovements) {
+            final UUID leastLoadedClient = 
clientsByTaskLoad.poll(movement.task);
+            if (leastLoadedClient == null) {
+                throw new IllegalStateException("Tried to move task to 
caught-up client but none exist");

Review comment:
       > We can possibly clarify it by making the name of 
taskIsCaughtUpOnClient more complete: eitherClientIsCaughtUpOnTaskOrNoClientIs.
   
   Agree on that
   
   > But I wouldn't hesitate to also write a nice letter to future us here as a 
comment.
   
   What about writing the nice letter to future us in the exception message 
instead of a comment?
   

##########
File path: 
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovement.java
##########
@@ -16,128 +16,94 @@
  */
 package org.apache.kafka.streams.processor.internals.assignment;
 
-import java.util.HashMap;
-import java.util.Iterator;
-import java.util.LinkedList;
+import static 
org.apache.kafka.streams.processor.internals.assignment.AssignmentUtils.taskIsCaughtUpOnClient;
+
 import java.util.List;
 import java.util.Map;
-import java.util.Objects;
-import java.util.Set;
 import java.util.SortedSet;
+import java.util.TreeSet;
 import java.util.UUID;
 import org.apache.kafka.streams.processor.TaskId;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 
 public class TaskMovement {
-    private static final Logger log = 
LoggerFactory.getLogger(TaskMovement.class);
-
     final TaskId task;
-    final UUID source;
-    final UUID destination;
+    private final UUID destination;
 
-    TaskMovement(final TaskId task, final UUID source, final UUID destination) 
{
+    TaskMovement(final TaskId task, final UUID destination) {
         this.task = task;
-        this.source = source;
         this.destination = destination;
     }
 
-    @Override
-    public boolean equals(final Object o) {
-        if (this == o) {
-            return true;
-        }
-        if (o == null || getClass() != o.getClass()) {
-            return false;
-        }
-        final TaskMovement movement = (TaskMovement) o;
-        return Objects.equals(task, movement.task) &&
-                   Objects.equals(source, movement.source) &&
-                   Objects.equals(destination, movement.destination);
-    }
-
-    @Override
-    public int hashCode() {
-        return Objects.hash(task, source, destination);
-    }
-
     /**
-     * Computes the movement of tasks from the state constrained to the 
balanced assignment, up to the configured
-     * {@code max.warmup.replicas}. A movement corresponds to a warmup replica 
on the destination client, with
-     * a few exceptional cases:
-     * <p>
-     * 1. Tasks whose destination clients are caught-up, or whose source 
clients are not caught-up, will be moved
-     * immediately from the source to the destination in the state constrained 
assignment
-     * 2. Tasks whose destination client previously had this task as a standby 
will not be counted towards the total
-     * {@code max.warmup.replicas}. Instead they will be counted against that 
task's total {@code num.standby.replicas}.
-     *
-     * @param statefulActiveTaskAssignment the initial, state constrained 
assignment, with the source clients
-     * @param balancedStatefulActiveTaskAssignment the final, balanced 
assignment, with the destination clients
-     * @return list of the task movements from statefulActiveTaskAssignment to 
balancedStatefulActiveTaskAssignment
+     * @return whether any warmup replicas were assigned
      */
-    static List<TaskMovement> getMovements(final Map<UUID, List<TaskId>> 
statefulActiveTaskAssignment,
-                                           final Map<UUID, List<TaskId>> 
balancedStatefulActiveTaskAssignment,
-                                           final Map<TaskId, SortedSet<UUID>> 
tasksToCaughtUpClients,
-                                           final Map<UUID, ClientState> 
clientStates,
-                                           final Map<TaskId, Integer> 
tasksToRemainingStandbys,
-                                           final int maxWarmupReplicas) {
-        if (statefulActiveTaskAssignment.size() != 
balancedStatefulActiveTaskAssignment.size()) {
-            throw new IllegalStateException("Tried to compute movements but 
assignments differ in size.");
-        }
+    static boolean assignTaskMovements(final Map<UUID, List<TaskId>> 
statefulActiveTaskAssignment,
+                                       final Map<TaskId, SortedSet<UUID>> 
tasksToCaughtUpClients,
+                                       final Map<UUID, ClientState> 
clientStates,
+                                       final Map<TaskId, Integer> 
tasksToRemainingStandbys,
+                                       final int maxWarmupReplicas) {
+        boolean warmupReplicasAssigned = false;
+
+        final ValidClientsByTaskLoadQueue clientsByTaskLoad =
+            new ValidClientsByTaskLoadQueue(
+                clientStates,
+                (client, task) -> taskIsCaughtUpOnClient(task, client, 
tasksToCaughtUpClients)
+            );
 
-        final Map<TaskId, UUID> taskToDestinationClient = new HashMap<>();
-        for (final Map.Entry<UUID, List<TaskId>> clientEntry : 
balancedStatefulActiveTaskAssignment.entrySet()) {
-            final UUID destination = clientEntry.getKey();
-            for (final TaskId task : clientEntry.getValue()) {
-                taskToDestinationClient.put(task, destination);
+        final SortedSet<TaskMovement> taskMovements = new TreeSet<>(
+            (movement, other) -> {
+                final int numCaughtUpClients = 
tasksToCaughtUpClients.get(movement.task).size();
+                final int otherNumCaughtUpClients = 
tasksToCaughtUpClients.get(other.task).size();
+                if (numCaughtUpClients != otherNumCaughtUpClients) {
+                    return numCaughtUpClients - otherNumCaughtUpClients;
+                } else {
+                    return movement.task.compareTo(other.task);
+                }
             }
+        );
+
+        for (final Map.Entry<UUID, List<TaskId>> assignmentEntry : 
statefulActiveTaskAssignment.entrySet()) {
+            final UUID client = assignmentEntry.getKey();
+            final ClientState state = clientStates.get(client);
+            for (final TaskId task : assignmentEntry.getValue()) {
+                if (taskIsCaughtUpOnClient(task, client, 
tasksToCaughtUpClients)) {
+                    state.assignActive(task);
+                } else {
+                    final TaskMovement taskMovement = new TaskMovement(task,  
client);
+                    taskMovements.add(taskMovement);
+                }
+            }
+            clientsByTaskLoad.offer(client);
         }
 
-        int remainingAllowedWarmupReplicas = maxWarmupReplicas;
-        final List<TaskMovement> movements = new LinkedList<>();
-        for (final Map.Entry<UUID, List<TaskId>> sourceClientEntry : 
statefulActiveTaskAssignment.entrySet()) {
-            final UUID source = sourceClientEntry.getKey();
+        int remainingWarmupReplicas = maxWarmupReplicas;
+        for (final TaskMovement movement : taskMovements) {
+            final UUID leastLoadedClient = 
clientsByTaskLoad.poll(movement.task);
+            if (leastLoadedClient == null) {
+                throw new IllegalStateException("Tried to move task to 
caught-up client but none exist");
+            }
 
-            final Iterator<TaskId> sourceClientTasksIterator = 
sourceClientEntry.getValue().iterator();
-            while (sourceClientTasksIterator.hasNext()) {
-                final TaskId task = sourceClientTasksIterator.next();
-                final UUID destination = taskToDestinationClient.get(task);
-                if (destination == null) {
-                    log.error("Task {} is assigned to client {} in initial 
assignment but has no owner in the final " +
-                                  "balanced assignment.", task, source);
-                    throw new IllegalStateException("Found task in initial 
assignment that was not assigned in the final.");
-                } else if (!source.equals(destination)) {
-                    if (destinationClientIsCaughtUp(task, destination, 
tasksToCaughtUpClients)) {
-                        sourceClientTasksIterator.remove();
-                        
statefulActiveTaskAssignment.get(destination).add(task);
-                    } else {
-                        if 
(clientStates.get(destination).prevStandbyTasks().contains(task)
-                                && tasksToRemainingStandbys.get(task) > 0
-                        ) {
-                            decrementRemainingStandbys(task, 
tasksToRemainingStandbys);
-                        } else {
-                            --remainingAllowedWarmupReplicas;
-                        }
+            final ClientState sourceClientState = 
clientStates.get(leastLoadedClient);
+            sourceClientState.assignActive(movement.task);
 
-                        movements.add(new TaskMovement(task, source, 
destination));
-                        if (remainingAllowedWarmupReplicas == 0) {
-                            return movements;
-                        }
-                    }
-                }
+            final ClientState destinationClientState = 
clientStates.get(movement.destination);
+            if 
(destinationClientState.prevStandbyTasks().contains(movement.task) && 
tasksToRemainingStandbys.get(movement.task) > 0) {

Review comment:
       Q: Why do we even care at all whether the task was running on the 
client? What if we just assign a real stand-by task if we have a spare one?  
   

##########
File path: 
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java
##########
@@ -89,95 +88,72 @@ public boolean assign() {
             return false;
         }
 
-        final Map<UUID, List<TaskId>> warmupTaskAssignment = 
initializeEmptyTaskAssignmentMap(sortedClients);
-        final Map<UUID, List<TaskId>> standbyTaskAssignment = 
initializeEmptyTaskAssignmentMap(sortedClients);
-        final Map<UUID, List<TaskId>> statelessActiveTaskAssignment = 
initializeEmptyTaskAssignmentMap(sortedClients);
+        final Map<TaskId, Integer> tasksToRemainingStandbys =
+            statefulTasks.stream().collect(Collectors.toMap(task -> task, t -> 
configs.numStandbyReplicas));
 
-        // ---------------- Stateful Active Tasks ---------------- //
+        final boolean followupRebalanceNeeded = 
assignStatefulActiveTasks(tasksToRemainingStandbys);
 
-        final Map<UUID, List<TaskId>> statefulActiveTaskAssignment =
-            new DefaultStateConstrainedBalancedAssignor().assign(
-                statefulTasksToRankedCandidates,
-                configs.balanceFactor,
-                sortedClients,
-                clientsToNumberOfThreads,
-                tasksToCaughtUpClients
-            );
+        assignStandbyReplicaTasks(tasksToRemainingStandbys);
+
+        assignStatelessActiveTasks();
 
-        // ---------------- Warmup Replica Tasks ---------------- //
+        return followupRebalanceNeeded;
+    }
 
-        final Map<UUID, List<TaskId>> balancedStatefulActiveTaskAssignment =
+    private boolean assignStatefulActiveTasks(final Map<TaskId, Integer> 
tasksToRemainingStandbys) {
+        final Map<UUID, List<TaskId>> statefulActiveTaskAssignment =
             new DefaultBalancedAssignor().assign(
                 sortedClients,
                 statefulTasks,
                 clientsToNumberOfThreads,
                 configs.balanceFactor);
 
-        final Map<TaskId, Integer> tasksToRemainingStandbys =
-            statefulTasks.stream().collect(Collectors.toMap(task -> task, t -> 
configs.numStandbyReplicas));
-
-        final List<TaskMovement> movements = getMovements(
+        return assignTaskMovements(
             statefulActiveTaskAssignment,
-            balancedStatefulActiveTaskAssignment,
             tasksToCaughtUpClients,
             clientStates,
             tasksToRemainingStandbys,
-            configs.maxWarmupReplicas);
-
-        for (final TaskMovement movement : movements) {
-            warmupTaskAssignment.get(movement.destination).add(movement.task);
-        }
-
-        // ---------------- Standby Replica Tasks ---------------- //
-
-        final List<Map<UUID, List<TaskId>>> allTaskAssignmentMaps = asList(
-            statefulActiveTaskAssignment,
-            warmupTaskAssignment,
-            standbyTaskAssignment,
-            statelessActiveTaskAssignment
+            configs.maxWarmupReplicas
         );
+    }
 
-        final ValidClientsByTaskLoadQueue<UUID> clientsByStandbyTaskLoad =
-            new ValidClientsByTaskLoadQueue<>(
-                getClientPriorityQueueByTaskLoad(allTaskAssignmentMaps),
-                allTaskAssignmentMaps
+    private void assignStandbyReplicaTasks(final Map<TaskId, Integer> 
tasksToRemainingStandbys) {
+        final ValidClientsByTaskLoadQueue standbyTaskClientsByTaskLoad =
+            new ValidClientsByTaskLoadQueue(
+                clientStates,
+                (client, task) -> 
!clientStates.get(client).assignedTasks().contains(task)
             );

Review comment:
       prop:
   ```suggestion
           final ValidClientsByTaskLoadQueue standbyTaskClientsByTaskLoad = new 
ValidClientsByTaskLoadQueue(
               clientStates,
               (client, task) -> 
!clientStates.get(client).assignedTasks().contains(task)
           );
   ```

##########
File path: 
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ValidClientsByTaskLoadQueue.java
##########
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import java.util.Set;
+import java.util.UUID;
+import java.util.function.BiFunction;
+import org.apache.kafka.streams.processor.TaskId;
+
+/**
+ * Wraps a priority queue of clients and returns the next valid candidate(s) 
based on the current task assignment
+ */
+class ValidClientsByTaskLoadQueue {
+    private final PriorityQueue<UUID> clientsByTaskLoad;
+    private final BiFunction<UUID, TaskId, Boolean> validClientCriteria;
+
+    ValidClientsByTaskLoadQueue(final Map<UUID, ClientState> clientStates,
+                                final BiFunction<UUID, TaskId, Boolean> 
validClientCriteria) {

Review comment:
       I do not remember having contributed to this awesomeness. It is all 
@ableegoldman's merit.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to