ableegoldman commented on a change in pull request #8588:
URL: https://github.com/apache/kafka/pull/8588#discussion_r424729806



##########
File path: 
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ConstrainedPrioritySet.java
##########
@@ -16,77 +16,58 @@
  */
 package org.apache.kafka.streams.processor.internals.assignment;
 
+import org.apache.kafka.streams.processor.TaskId;
+
 import java.util.Collection;
+import java.util.Comparator;
 import java.util.HashSet;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.Map;
 import java.util.PriorityQueue;
 import java.util.Set;
 import java.util.UUID;
 import java.util.function.BiFunction;
-import org.apache.kafka.streams.processor.TaskId;
+import java.util.function.Function;
 
 /**
  * Wraps a priority queue of clients and returns the next valid candidate(s) 
based on the current task assignment
  */
-class ValidClientsByTaskLoadQueue {
+class ConstrainedPrioritySet {
 
     private final PriorityQueue<UUID> clientsByTaskLoad;
-    private final BiFunction<UUID, TaskId, Boolean> validClientCriteria;
+    private final BiFunction<UUID, TaskId, Boolean> constraint;
     private final Set<UUID> uniqueClients = new HashSet<>();
 
-    ValidClientsByTaskLoadQueue(final Map<UUID, ClientState> clientStates,
-                                final BiFunction<UUID, TaskId, Boolean> 
validClientCriteria) {
-        this.validClientCriteria = validClientCriteria;
-
-        clientsByTaskLoad = new PriorityQueue<>(
-            (client, other) -> {
-                final double clientTaskLoad = 
clientStates.get(client).taskLoad();
-                final double otherTaskLoad = 
clientStates.get(other).taskLoad();
-                if (clientTaskLoad < otherTaskLoad) {
-                    return -1;
-                } else if (clientTaskLoad > otherTaskLoad) {
-                    return 1;
-                } else {
-                    return client.compareTo(other);
-                }
-            });
+    ConstrainedPrioritySet(final BiFunction<UUID, TaskId, Boolean> constraint,
+                           final Function<UUID, Double> weight) {
+        this.constraint = constraint;
+        clientsByTaskLoad = new 
PriorityQueue<>(Comparator.comparing(weight).thenComparing(clientId -> 
clientId));
     }
 
     /**
      * @return the next least loaded client that satisfies the given criteria, 
or null if none do
      */
-    UUID poll(final TaskId task) {
-        final List<UUID> validClient = poll(task, 1);
-        return validClient.isEmpty() ? null : validClient.get(0);
-    }
-
-    /**
-     * @return the next N <= {@code numClientsPerTask} clients in the 
underlying priority queue that are valid candidates for the given task
-     */
-    List<UUID> poll(final TaskId task, final int numClients) {
-        final List<UUID> nextLeastLoadedValidClients = new LinkedList<>();
+    UUID poll(final TaskId task, final Function<UUID, Boolean> 
extraConstraint) {

Review comment:
       I'm not sure I see how the returned clients could ever be different 
using "poll N clients" vs "poll N times". Only the clients that are getting a 
new task assigned will have their weight changed in the middle of an N-client 
poll, and once we assign this task to that client it no longer meets the 
criteria, so we don't care about it anyway, right?
   
   The reason for the "poll N clients" method was to save on some of the 
polling and re-offering of clients that don't meet the criteria, but I don't 
think that's really worth worrying about. I'm fine with whatever code is 
easiest to read -- I just want to understand why this affects the balance.

##########
File path: 
streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovementTest.java
##########
@@ -35,262 +44,161 @@
 import static 
org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_2;
 import static 
org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_3;
 import static 
org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.getClientStatesMap;
+import static 
org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.hasProperty;
 import static 
org.apache.kafka.streams.processor.internals.assignment.TaskMovement.assignTaskMovements;
 import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.Matchers.equalTo;
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.assertFalse;
-
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import java.util.SortedSet;
-import java.util.UUID;
-import java.util.stream.Collectors;
-import org.apache.kafka.streams.processor.TaskId;
-import org.junit.Test;
+import static org.hamcrest.Matchers.is;
 
 public class TaskMovementTest {
-    private final ClientState client1 = new ClientState(1);
-    private final ClientState client2 = new ClientState(1);
-    private final ClientState client3 = new ClientState(1);
-
-    private final Map<UUID, ClientState> clientStates = 
getClientStatesMap(client1, client2, client3);
-
-    private final Map<UUID, List<TaskId>> emptyWarmupAssignment = mkMap(
-        mkEntry(UUID_1, EMPTY_TASK_LIST),
-        mkEntry(UUID_2, EMPTY_TASK_LIST),
-        mkEntry(UUID_3, EMPTY_TASK_LIST)
-    );
-
     @Test
     public void 
shouldAssignTasksToClientsAndReturnFalseWhenAllClientsCaughtUp() {
         final int maxWarmupReplicas = Integer.MAX_VALUE;
         final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, 
TASK_1_0, TASK_1_1, TASK_1_2);
 
-        final Map<UUID, List<TaskId>> balancedAssignment = mkMap(
-            mkEntry(UUID_1, asList(TASK_0_0, TASK_1_0)),
-            mkEntry(UUID_2, asList(TASK_0_1, TASK_1_1)),
-            mkEntry(UUID_3, asList(TASK_0_2, TASK_1_2))
-        );
-
         final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = new 
HashMap<>();
         for (final TaskId task : allTasks) {
             tasksToCaughtUpClients.put(task, mkSortedSet(UUID_1, UUID_2, 
UUID_3));
         }
-        
-        assertFalse(
+
+        final ClientState client1 = 
getClientStateWithActiveAssignment(asList(TASK_0_0, TASK_1_0));
+        final ClientState client2 = 
getClientStateWithActiveAssignment(asList(TASK_0_1, TASK_1_1));
+        final ClientState client3 = 
getClientStateWithActiveAssignment(asList(TASK_0_2, TASK_1_2));
+
+        assertThat(
             assignTaskMovements(
-                balancedAssignment,
                 tasksToCaughtUpClients,
-                clientStates,
-                getMapWithNumStandbys(allTasks, 1),
-                maxWarmupReplicas)
+                getClientStatesMap(client1, client2, client3),
+                maxWarmupReplicas),
+            is(false)
         );
-
-        verifyClientStateAssignments(balancedAssignment, 
emptyWarmupAssignment);
     }
 
     @Test
     public void 
shouldAssignAllTasksToClientsAndReturnFalseIfNoClientsAreCaughtUp() {
-        final int maxWarmupReplicas = 2;
-        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, 
TASK_1_0, TASK_1_1, TASK_1_2);
+        final int maxWarmupReplicas = Integer.MAX_VALUE;
 
-        final Map<UUID, List<TaskId>> balancedAssignment = mkMap(
-            mkEntry(UUID_1, asList(TASK_0_0, TASK_1_0)),
-            mkEntry(UUID_2, asList(TASK_0_1, TASK_1_1)),
-            mkEntry(UUID_3, asList(TASK_0_2, TASK_1_2))
-        );
+        final ClientState client1 = 
getClientStateWithActiveAssignment(asList(TASK_0_0, TASK_1_0));
+        final ClientState client2 = 
getClientStateWithActiveAssignment(asList(TASK_0_1, TASK_1_1));
+        final ClientState client3 = 
getClientStateWithActiveAssignment(asList(TASK_0_2, TASK_1_2));
 
-        assertFalse(
+        assertThat(
             assignTaskMovements(
-                balancedAssignment,
                 emptyMap(),
-                clientStates,
-                getMapWithNumStandbys(allTasks, 1),
-                maxWarmupReplicas)
+                getClientStatesMap(client1, client2, client3),
+                maxWarmupReplicas),
+            is(false)
         );
-        verifyClientStateAssignments(balancedAssignment, 
emptyWarmupAssignment);
     }
 
     @Test
     public void 
shouldMoveTasksToCaughtUpClientsAndAssignWarmupReplicasInTheirPlace() {
         final int maxWarmupReplicas = Integer.MAX_VALUE;
-        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2);
+        final ClientState client1 = 
getClientStateWithActiveAssignment(singletonList(TASK_0_0));
+        final ClientState client2 = 
getClientStateWithActiveAssignment(singletonList(TASK_0_1));
+        final ClientState client3 = 
getClientStateWithActiveAssignment(singletonList(TASK_0_2));
+        final Map<UUID, ClientState> clientStates = 
getClientStatesMap(client1, client2, client3);
 
-        final Map<UUID, List<TaskId>> balancedAssignment = mkMap(
-            mkEntry(UUID_1, singletonList(TASK_0_0)),
-            mkEntry(UUID_2, singletonList(TASK_0_1)),
-            mkEntry(UUID_3, singletonList(TASK_0_2))
+        final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = mkMap(
+            mkEntry(TASK_0_0, mkSortedSet(UUID_1)),
+            mkEntry(TASK_0_1, mkSortedSet(UUID_3)),
+            mkEntry(TASK_0_2, mkSortedSet(UUID_2))
         );
 
-        final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = new 
HashMap<>();
-        tasksToCaughtUpClients.put(TASK_0_0, mkSortedSet(UUID_1));
-        tasksToCaughtUpClients.put(TASK_0_1, mkSortedSet(UUID_3));
-        tasksToCaughtUpClients.put(TASK_0_2, mkSortedSet(UUID_2));
-
-        final Map<UUID, List<TaskId>> expectedActiveTaskAssignment = mkMap(
-            mkEntry(UUID_1, singletonList(TASK_0_0)),
-            mkEntry(UUID_2, singletonList(TASK_0_2)),
-            mkEntry(UUID_3, singletonList(TASK_0_1))
-        );
-
-        final Map<UUID, List<TaskId>> expectedWarmupTaskAssignment = mkMap(
-            mkEntry(UUID_1, EMPTY_TASK_LIST),
-            mkEntry(UUID_2, singletonList(TASK_0_1)),
-            mkEntry(UUID_3, singletonList(TASK_0_2))
-        );
-
-        assertTrue(
+        assertThat(
+            "should have assigned movements",
             assignTaskMovements(
-                balancedAssignment,
                 tasksToCaughtUpClients,
                 clientStates,
-                getMapWithNumStandbys(allTasks, 1),
-                maxWarmupReplicas)
-        );
-        verifyClientStateAssignments(expectedActiveTaskAssignment, 
expectedWarmupTaskAssignment);
-    }
-
-    @Test
-    public void shouldProduceBalancedAndStateConstrainedAssignment() {

Review comment:
       IIRC this test was covering an edge case where we might produce an 
unbalanced assignment. But it may be moot at this point (and besides, we don't 
necessarily need to produce a perfectly balanced assignment here).

##########
File path: 
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovement.java
##########
@@ -53,75 +67,94 @@ private static boolean 
taskIsCaughtUpOnClientOrNoCaughtUpClientsExist(final Task
     /**
      * @return whether any warmup replicas were assigned
      */
-    static boolean assignTaskMovements(final Map<UUID, List<TaskId>> 
statefulActiveTaskAssignment,
-                                       final Map<TaskId, SortedSet<UUID>> 
tasksToCaughtUpClients,
+    static boolean assignTaskMovements(final Map<TaskId, SortedSet<UUID>> 
tasksToCaughtUpClients,
                                        final Map<UUID, ClientState> 
clientStates,
-                                       final Map<TaskId, Integer> 
tasksToRemainingStandbys,
                                        final int maxWarmupReplicas) {
-        boolean warmupReplicasAssigned = false;
+        final BiFunction<UUID, TaskId, Boolean> caughtUpPredicate =
+            (client, task) -> 
taskIsCaughtUpOnClientOrNoCaughtUpClientsExist(task, client, 
tasksToCaughtUpClients);
 
-        final ValidClientsByTaskLoadQueue clientsByTaskLoad = new 
ValidClientsByTaskLoadQueue(
-            clientStates,
-            (client, task) -> 
taskIsCaughtUpOnClientOrNoCaughtUpClientsExist(task, client, 
tasksToCaughtUpClients)
+        final ConstrainedPrioritySet clientsByTaskLoad = new 
ConstrainedPrioritySet(

Review comment:
       nit: I know I named this in the first place but can we change it to 
`caughtUpClientsByTaskLoad` or something?

##########
File path: 
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovement.java
##########
@@ -53,75 +67,94 @@ private static boolean 
taskIsCaughtUpOnClientOrNoCaughtUpClientsExist(final Task
     /**
      * @return whether any warmup replicas were assigned
      */
-    static boolean assignTaskMovements(final Map<UUID, List<TaskId>> 
statefulActiveTaskAssignment,
-                                       final Map<TaskId, SortedSet<UUID>> 
tasksToCaughtUpClients,
+    static boolean assignTaskMovements(final Map<TaskId, SortedSet<UUID>> 
tasksToCaughtUpClients,
                                        final Map<UUID, ClientState> 
clientStates,
-                                       final Map<TaskId, Integer> 
tasksToRemainingStandbys,
                                        final int maxWarmupReplicas) {
-        boolean warmupReplicasAssigned = false;
+        final BiFunction<UUID, TaskId, Boolean> caughtUpPredicate =
+            (client, task) -> 
taskIsCaughtUpOnClientOrNoCaughtUpClientsExist(task, client, 
tasksToCaughtUpClients);
 
-        final ValidClientsByTaskLoadQueue clientsByTaskLoad = new 
ValidClientsByTaskLoadQueue(
-            clientStates,
-            (client, task) -> 
taskIsCaughtUpOnClientOrNoCaughtUpClientsExist(task, client, 
tasksToCaughtUpClients)
+        final ConstrainedPrioritySet clientsByTaskLoad = new 
ConstrainedPrioritySet(
+            caughtUpPredicate,
+            client -> clientStates.get(client).taskLoad()
         );
 
-        final SortedSet<TaskMovement> taskMovements = new TreeSet<>(
-            (movement, other) -> {
-                final int numCaughtUpClients = movement.caughtUpClients.size();
-                final int otherNumCaughtUpClients = 
other.caughtUpClients.size();
-                if (numCaughtUpClients != otherNumCaughtUpClients) {
-                    return Integer.compare(numCaughtUpClients, 
otherNumCaughtUpClients);
-                } else {
-                    return movement.task.compareTo(other.task);
-                }
-            }
+        final Queue<TaskMovement> taskMovements = new PriorityQueue<>(
+            
Comparator.comparing(TaskMovement::numCaughtUpClients).thenComparing(TaskMovement::task)
         );
 
-        for (final Map.Entry<UUID, List<TaskId>> assignmentEntry : 
statefulActiveTaskAssignment.entrySet()) {
-            final UUID client = assignmentEntry.getKey();
-            final ClientState state = clientStates.get(client);
-            for (final TaskId task : assignmentEntry.getValue()) {
-                if (taskIsCaughtUpOnClientOrNoCaughtUpClientsExist(task, 
client, tasksToCaughtUpClients)) {
-                    state.assignActive(task);
-                } else {
-                    final TaskMovement taskMovement = new TaskMovement(task, 
client, tasksToCaughtUpClients.get(task));
-                    taskMovements.add(taskMovement);
+        for (final Map.Entry<UUID, ClientState> clientStateEntry : 
clientStates.entrySet()) {
+            final UUID client = clientStateEntry.getKey();
+            final ClientState state = clientStateEntry.getValue();
+            for (final TaskId task : state.activeTasks()) {
+                // if the desired client is not caught up, and there is 
another client that _is_ caught up, then
+                // we schedule a movement, so we can move the active task to 
the caught-up client. We'll try to
+                // assign a warm-up to the desired client so that we can move 
it later on.
+                if (!taskIsCaughtUpOnClientOrNoCaughtUpClientsExist(task, 
client, tasksToCaughtUpClients)) {
+                    taskMovements.add(new TaskMovement(task, client, 
tasksToCaughtUpClients.get(task)));
                 }
             }
             clientsByTaskLoad.offer(client);
         }
 
+        final boolean movementsNeeded = !taskMovements.isEmpty();
+
         final AtomicInteger remainingWarmupReplicas = new 
AtomicInteger(maxWarmupReplicas);
         for (final TaskMovement movement : taskMovements) {
-            final UUID sourceClient = clientsByTaskLoad.poll(movement.task);
-            if (sourceClient == null) {
-                throw new IllegalStateException("Tried to move task to 
caught-up client but none exist");
-            }
-
-            final ClientState sourceClientState = 
clientStates.get(sourceClient);
-            sourceClientState.assignActive(movement.task);
-            clientsByTaskLoad.offer(sourceClient);
+            final UUID standbySourceClient = clientsByTaskLoad.poll(

Review comment:
       This is a really nice touch 👍 That said, without attempting some degree 
of stickiness in the standby task assignment, it seems unlikely that we would 
actually find a standby on a caught-up client.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to