This is an automated email from the ASF dual-hosted git repository.

cadonna pushed a commit to branch 3.2
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/3.2 by this push:
     new ac11383  KAFKA-13600: Kafka Streams - Fall back to most caught up 
client if no caught up clients exist (#11760)
ac11383 is described below

commit ac11383365771895d8ca4bce8fb8f0882621a38f
Author: Tim Patterson <[email protected]>
AuthorDate: Tue Mar 29 03:48:39 2022 +1300

    KAFKA-13600: Kafka Streams - Fall back to most caught up client if no 
caught up clients exist (#11760)
    
    The task assignor is modified to consider the Streams client with the most 
caught up states if no Streams client exists that is caught up, i.e., the lag 
of the states on that client is less than the acceptable recovery lag.
    
    Unit test for the case of task assignment where no caught-up nodes exist.
    Existing unit and integration tests verify that no other behaviour has been 
changed.
    
    Co-authored-by: Bruno Cadonna <[email protected]>
    
    Reviewer: Bruno Cadonna <[email protected]>
---
 .../assignment/HighAvailabilityTaskAssignor.java   |  17 ++
 .../internals/assignment/TaskMovement.java         | 178 +++++++++++++++------
 .../HighAvailabilityTaskAssignorTest.java          |  31 ++++
 .../internals/assignment/TaskMovementTest.java     | 144 ++++++++++++++---
 4 files changed, 303 insertions(+), 67 deletions(-)

diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java
index 7111ae2..c54199a 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java
@@ -22,6 +22,7 @@ import 
org.apache.kafka.streams.processor.internals.assignment.AssignorConfigura
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.Map;
@@ -68,6 +69,8 @@ public class HighAvailabilityTaskAssignor implements 
TaskAssignor {
             configs.acceptableRecoveryLag
         );
 
+        final Map<TaskId, SortedSet<UUID>> tasksToClientByLag = 
tasksToClientByLag(statefulTasks, clientStates);
+
         // We temporarily need to know which standby tasks were intended as 
warmups
         // for active tasks, so that we don't move them (again) when we plan 
standby
         // task movements. We can then immediately treat warmups exactly the 
same as
@@ -77,6 +80,7 @@ public class HighAvailabilityTaskAssignor implements 
TaskAssignor {
 
         final int neededActiveTaskMovements = assignActiveTaskMovements(
             tasksToCaughtUpClients,
+            tasksToClientByLag,
             clientStates,
             warmups,
             remainingWarmupReplicas
@@ -84,6 +88,7 @@ public class HighAvailabilityTaskAssignor implements 
TaskAssignor {
 
         final int neededStandbyTaskMovements = assignStandbyTaskMovements(
             tasksToCaughtUpClients,
+            tasksToClientByLag,
             clientStates,
             remainingWarmupReplicas,
             warmups
@@ -238,6 +243,18 @@ public class HighAvailabilityTaskAssignor implements 
TaskAssignor {
         return taskToCaughtUpClients;
     }
 
+    private static Map<TaskId, SortedSet<UUID>> tasksToClientByLag(final 
Set<TaskId> statefulTasks,
+                                                              final Map<UUID, 
ClientState> clientStates) {
+        final Map<TaskId, SortedSet<UUID>> tasksToClientByLag = new 
HashMap<>();
+        for (final TaskId task : statefulTasks) {
+            final SortedSet<UUID> clientLag = new 
TreeSet<>(Comparator.<UUID>comparingLong(a ->
+                    clientStates.get(a).lagFor(task)).thenComparing(a -> a));
+            clientLag.addAll(clientStates.keySet());
+            tasksToClientByLag.put(task, clientLag);
+        }
+        return tasksToClientByLag;
+    }
+
     private static boolean unbounded(final long acceptableRecoveryLag) {
         return acceptableRecoveryLag == Long.MAX_VALUE;
     }
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovement.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovement.java
index cbfa3da..38e6427 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovement.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovement.java
@@ -29,6 +29,7 @@ import java.util.TreeSet;
 import java.util.UUID;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.BiFunction;
+import java.util.function.Function;
 
 import static java.util.Arrays.asList;
 import static java.util.Objects.requireNonNull;
@@ -42,10 +43,6 @@ final class TaskMovement {
         this.task = task;
         this.destination = destination;
         this.caughtUpClients = caughtUpClients;
-
-        if (caughtUpClients == null || caughtUpClients.isEmpty()) {
-            throw new IllegalStateException("Should not attempt to move a task 
if no caught up clients exist");
-        }
     }
 
     private TaskId task() {
@@ -56,25 +53,34 @@ final class TaskMovement {
         return caughtUpClients.size();
     }
 
-    private static boolean 
taskIsNotCaughtUpOnClientAndOtherCaughtUpClientsExist(final TaskId task,
-                                                                               
  final UUID client,
-                                                                               
  final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients) {
-        return !taskIsCaughtUpOnClientOrNoCaughtUpClientsExist(task, client, 
tasksToCaughtUpClients);
+    private static boolean 
taskIsNotCaughtUpOnClientAndOtherMoreCaughtUpClientsExist(final TaskId task,
+                                                                               
      final UUID client,
+                                                                               
      final Map<UUID, ClientState> clientStates,
+                                                                               
      final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients,
+                                                                               
      final Map<TaskId, SortedSet<UUID>> tasksToClientByLag) {
+        final SortedSet<UUID> taskClients = 
requireNonNull(tasksToClientByLag.get(task), "uninitialized set");
+        if (taskIsCaughtUpOnClient(task, client, tasksToCaughtUpClients)) {
+            return false;
+        }
+        final long mostCaughtUpLag = 
clientStates.get(taskClients.first()).lagFor(task);
+        final long clientLag = clientStates.get(client).lagFor(task);
+        return mostCaughtUpLag < clientLag;
     }
 
-    private static boolean 
taskIsCaughtUpOnClientOrNoCaughtUpClientsExist(final TaskId task,
-                                                                          
final UUID client,
-                                                                          
final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients) {
+    private static boolean taskIsCaughtUpOnClient(final TaskId task,
+                                                  final UUID client,
+                                                  final Map<TaskId, 
SortedSet<UUID>> tasksToCaughtUpClients) {
         final Set<UUID> caughtUpClients = 
requireNonNull(tasksToCaughtUpClients.get(task), "uninitialized set");
-        return caughtUpClients.isEmpty() || caughtUpClients.contains(client);
+        return caughtUpClients.contains(client);
     }
 
     static int assignActiveTaskMovements(final Map<TaskId, SortedSet<UUID>> 
tasksToCaughtUpClients,
+                                         final Map<TaskId, SortedSet<UUID>> 
tasksToClientByLag,
                                          final Map<UUID, ClientState> 
clientStates,
                                          final Map<UUID, Set<TaskId>> warmups,
                                          final AtomicInteger 
remainingWarmupReplicas) {
         final BiFunction<UUID, TaskId, Boolean> caughtUpPredicate =
-            (client, task) -> 
taskIsCaughtUpOnClientOrNoCaughtUpClientsExist(task, client, 
tasksToCaughtUpClients);
+            (client, task) -> taskIsCaughtUpOnClient(task, client, 
tasksToCaughtUpClients);
 
         final ConstrainedPrioritySet caughtUpClientsByTaskLoad = new 
ConstrainedPrioritySet(
             caughtUpPredicate,
@@ -89,10 +95,10 @@ final class TaskMovement {
             final UUID client = clientStateEntry.getKey();
             final ClientState state = clientStateEntry.getValue();
             for (final TaskId task : state.activeTasks()) {
-                // if the desired client is not caught up, and there is 
another client that _is_ caught up, then
-                // we schedule a movement, so we can move the active task to 
the caught-up client. We'll try to
+                // if the desired client is not caught up, and there is 
another client that _is_ more caught up, then
+                // we schedule a movement, so we can move the active task to a 
more caught-up client. We'll try to
                 // assign a warm-up to the desired client so that we can move 
it later on.
-                if 
(taskIsNotCaughtUpOnClientAndOtherCaughtUpClientsExist(task, client, 
tasksToCaughtUpClients)) {
+                if 
(taskIsNotCaughtUpOnClientAndOtherMoreCaughtUpClientsExist(task, client, 
clientStates, tasksToCaughtUpClients, tasksToClientByLag)) {
                     taskMovements.add(new TaskMovement(task, client, 
tasksToCaughtUpClients.get(task)));
                 }
             }
@@ -102,33 +108,14 @@ final class TaskMovement {
         final int movementsNeeded = taskMovements.size();
 
         for (final TaskMovement movement : taskMovements) {
-            final UUID standbySourceClient = caughtUpClientsByTaskLoad.poll(
-                movement.task,
-                c -> clientStates.get(c).hasStandbyTask(movement.task)
-            );
-            if (standbySourceClient == null) {
-                // there's not a caught-up standby available to take over the 
task, so we'll schedule a warmup instead
-                final UUID sourceClient = requireNonNull(
-                    caughtUpClientsByTaskLoad.poll(movement.task),
-                    "Tried to move task to caught-up client but none exist"
-                );
-
-                moveActiveAndTryToWarmUp(
-                    remainingWarmupReplicas,
-                    movement.task,
-                    clientStates.get(sourceClient),
-                    clientStates.get(movement.destination),
-                    warmups.computeIfAbsent(movement.destination, x -> new 
TreeSet<>())
-                );
-                caughtUpClientsByTaskLoad.offerAll(asList(sourceClient, 
movement.destination));
-            } else {
-                // we found a candidate to trade standby/active state with our 
destination, so we don't need a warmup
-                swapStandbyAndActive(
-                    movement.task,
-                    clientStates.get(standbySourceClient),
-                    clientStates.get(movement.destination)
-                );
-                caughtUpClientsByTaskLoad.offerAll(asList(standbySourceClient, 
movement.destination));
+            // Attempt to find a caught up standby, otherwise find any caught 
up client, failing that use the most
+            // caught up client.
+            final boolean moved = 
tryToSwapStandbyAndActiveOnCaughtUpClient(clientStates, 
caughtUpClientsByTaskLoad, movement) ||
+                    
tryToMoveActiveToCaughtUpClientAndTryToWarmUp(clientStates, warmups, 
remainingWarmupReplicas, caughtUpClientsByTaskLoad, movement) ||
+                    tryToMoveActiveToMostCaughtUpClient(tasksToClientByLag, 
clientStates, warmups, remainingWarmupReplicas, caughtUpClientsByTaskLoad, 
movement);
+
+            if (!moved) {
+                throw new IllegalStateException("Tried to move task to more 
caught-up client as scheduled before but none exist");
             }
         }
 
@@ -136,11 +123,12 @@ final class TaskMovement {
     }
 
     static int assignStandbyTaskMovements(final Map<TaskId, SortedSet<UUID>> 
tasksToCaughtUpClients,
+                                          final Map<TaskId, SortedSet<UUID>> 
tasksToClientByLag,
                                           final Map<UUID, ClientState> 
clientStates,
                                           final AtomicInteger 
remainingWarmupReplicas,
                                           final Map<UUID, Set<TaskId>> 
warmups) {
         final BiFunction<UUID, TaskId, Boolean> caughtUpPredicate =
-            (client, task) -> 
taskIsCaughtUpOnClientOrNoCaughtUpClientsExist(task, client, 
tasksToCaughtUpClients);
+            (client, task) -> taskIsCaughtUpOnClient(task, client, 
tasksToCaughtUpClients);
 
         final ConstrainedPrioritySet caughtUpClientsByTaskLoad = new 
ConstrainedPrioritySet(
             caughtUpPredicate,
@@ -157,8 +145,8 @@ final class TaskMovement {
             for (final TaskId task : state.standbyTasks()) {
                 if (warmups.getOrDefault(destination, 
Collections.emptySet()).contains(task)) {
                     // this is a warmup, so we won't move it.
-                } else if 
(taskIsNotCaughtUpOnClientAndOtherCaughtUpClientsExist(task, destination, 
tasksToCaughtUpClients)) {
-                    // if the desired client is not caught up, and there is 
another client that _is_ caught up, then
+                } else if 
(taskIsNotCaughtUpOnClientAndOtherMoreCaughtUpClientsExist(task, destination, 
clientStates, tasksToCaughtUpClients, tasksToClientByLag)) {
+                    // if the desired client is not caught up, and there is 
another client that _is_ more caught up, then
                     // we schedule a movement, so we can move the active task 
to the caught-up client. We'll try to
                     // assign a warm-up to the desired client so that we can 
move it later on.
                     taskMovements.add(new TaskMovement(task, destination, 
tasksToCaughtUpClients.get(task)));
@@ -170,12 +158,18 @@ final class TaskMovement {
         int movementsNeeded = 0;
 
         for (final TaskMovement movement : taskMovements) {
-            final UUID sourceClient = caughtUpClientsByTaskLoad.poll(
+            final Function<UUID, Boolean> eligibleClientPredicate =
+                    clientId -> 
!clientStates.get(clientId).hasAssignedTask(movement.task);
+            UUID sourceClient = caughtUpClientsByTaskLoad.poll(
                 movement.task,
-                clientId -> 
!clientStates.get(clientId).hasAssignedTask(movement.task)
+                eligibleClientPredicate
             );
 
             if (sourceClient == null) {
+                sourceClient = mostCaughtUpEligibleClient(tasksToClientByLag, 
eligibleClientPredicate, movement.task, movement.destination);
+            }
+
+            if (sourceClient == null) {
                 // then there's no caught-up client that doesn't already have 
a copy of this task, so there's
                 // nowhere to move it.
             } else {
@@ -193,6 +187,74 @@ final class TaskMovement {
         return movementsNeeded;
     }
 
+    private static boolean tryToSwapStandbyAndActiveOnCaughtUpClient(final 
Map<UUID, ClientState> clientStates,
+                                                                     final 
ConstrainedPrioritySet caughtUpClientsByTaskLoad,
+                                                                     final 
TaskMovement movement) {
+        final UUID caughtUpStandbySourceClient = 
caughtUpClientsByTaskLoad.poll(
+                movement.task,
+                c -> clientStates.get(c).hasStandbyTask(movement.task)
+        );
+        if (caughtUpStandbySourceClient != null) {
+            swapStandbyAndActive(
+                    movement.task,
+                    clientStates.get(caughtUpStandbySourceClient),
+                    clientStates.get(movement.destination)
+            );
+            
caughtUpClientsByTaskLoad.offerAll(asList(caughtUpStandbySourceClient, 
movement.destination));
+            return true;
+        }
+        return false;
+    }
+
+    private static boolean tryToMoveActiveToCaughtUpClientAndTryToWarmUp(final 
Map<UUID, ClientState> clientStates,
+                                                                         final 
Map<UUID, Set<TaskId>> warmups,
+                                                                         final 
AtomicInteger remainingWarmupReplicas,
+                                                                         final 
ConstrainedPrioritySet caughtUpClientsByTaskLoad,
+                                                                         final 
TaskMovement movement) {
+        final UUID caughtUpSourceClient = 
caughtUpClientsByTaskLoad.poll(movement.task);
+        if (caughtUpSourceClient != null) {
+            moveActiveAndTryToWarmUp(
+                    remainingWarmupReplicas,
+                    movement.task,
+                    clientStates.get(caughtUpSourceClient),
+                    clientStates.get(movement.destination),
+                    warmups.computeIfAbsent(movement.destination, x -> new 
TreeSet<>())
+            );
+            caughtUpClientsByTaskLoad.offerAll(asList(caughtUpSourceClient, 
movement.destination));
+            return true;
+        }
+        return false;
+    }
+
+    private static boolean tryToMoveActiveToMostCaughtUpClient(final 
Map<TaskId, SortedSet<UUID>> tasksToClientByLag,
+                                                               final Map<UUID, 
ClientState> clientStates,
+                                                               final Map<UUID, 
Set<TaskId>> warmups,
+                                                               final 
AtomicInteger remainingWarmupReplicas,
+                                                               final 
ConstrainedPrioritySet caughtUpClientsByTaskLoad,
+                                                               final 
TaskMovement movement) {
+        final UUID mostCaughtUpSourceClient = 
mostCaughtUpEligibleClient(tasksToClientByLag, movement.task, 
movement.destination);
+        if (mostCaughtUpSourceClient != null) {
+            if 
(clientStates.get(mostCaughtUpSourceClient).hasStandbyTask(movement.task)) {
+                swapStandbyAndActive(
+                        movement.task,
+                        clientStates.get(mostCaughtUpSourceClient),
+                        clientStates.get(movement.destination)
+                );
+            } else {
+                moveActiveAndTryToWarmUp(
+                        remainingWarmupReplicas,
+                        movement.task,
+                        clientStates.get(mostCaughtUpSourceClient),
+                        clientStates.get(movement.destination),
+                        warmups.computeIfAbsent(movement.destination, x -> new 
TreeSet<>())
+                );
+            }
+            
caughtUpClientsByTaskLoad.offerAll(asList(mostCaughtUpSourceClient, 
movement.destination));
+            return true;
+        }
+        return false;
+    }
+
     private static void moveActiveAndTryToWarmUp(final AtomicInteger 
remainingWarmupReplicas,
                                                  final TaskId task,
                                                  final ClientState 
sourceClientState,
@@ -235,4 +297,24 @@ final class TaskMovement {
         destinationClientState.assignStandby(task);
     }
 
+    private static UUID mostCaughtUpEligibleClient(final Map<TaskId, 
SortedSet<UUID>> tasksToClientByLag,
+                                                   final TaskId task,
+                                                   final UUID 
destinationClient) {
+        return mostCaughtUpEligibleClient(tasksToClientByLag, client -> true, 
task, destinationClient);
+    }
+
+    private static UUID mostCaughtUpEligibleClient(final Map<TaskId, 
SortedSet<UUID>> tasksToClientByLag,
+                                                   final Function<UUID, 
Boolean> constraint,
+                                                   final TaskId task,
+                                                   final UUID 
destinationClient) {
+        for (final UUID client : tasksToClientByLag.get(task)) {
+            if (destinationClient.equals(client)) {
+                break;
+            } else if (constraint.apply(client)) {
+                return client;
+            }
+        }
+        return null;
+    }
+
 }
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignorTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignorTest.java
index 36ae42f..90e0fed 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignorTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignorTest.java
@@ -420,6 +420,37 @@ public class HighAvailabilityTaskAssignorTest {
     }
 
     @Test
+    public void 
shouldAssignToMostCaughtUpIfActiveTasksWasNotOnCaughtUpClient() {
+        final Set<TaskId> allTasks = mkSet(TASK_0_0);
+        final Set<TaskId> statefulTasks = mkSet(TASK_0_0);
+        final ClientState client1 = new ClientState(emptySet(), emptySet(), 
singletonMap(TASK_0_0, Long.MAX_VALUE), EMPTY_CLIENT_TAGS, 1);
+        final ClientState client2 = new ClientState(emptySet(), emptySet(), 
singletonMap(TASK_0_0, 1000L), EMPTY_CLIENT_TAGS, 1);
+        final ClientState client3 = new ClientState(emptySet(), emptySet(), 
singletonMap(TASK_0_0, 500L), EMPTY_CLIENT_TAGS, 1);
+        final Map<UUID, ClientState> clientStates = mkMap(
+                mkEntry(UUID_1, client1),
+                mkEntry(UUID_2, client2),
+                mkEntry(UUID_3, client3)
+        );
+
+        final boolean probingRebalanceNeeded =
+                new HighAvailabilityTaskAssignor().assign(clientStates, 
allTasks, statefulTasks, configWithStandbys);
+
+        assertThat(clientStates.get(UUID_1).activeTasks(), is(emptySet()));
+        assertThat(clientStates.get(UUID_2).activeTasks(), is(emptySet()));
+        assertThat(clientStates.get(UUID_3).activeTasks(), 
is(singleton(TASK_0_0)));
+
+        assertThat(clientStates.get(UUID_1).standbyTasks(), 
is(singleton(TASK_0_0))); // warm up
+        assertThat(clientStates.get(UUID_2).standbyTasks(), 
is(singleton(TASK_0_0))); // standby
+        assertThat(clientStates.get(UUID_3).standbyTasks(), is(emptySet()));
+
+        assertThat(probingRebalanceNeeded, is(true));
+        assertValidAssignment(1, 1, allTasks, emptySet(), clientStates, new 
StringBuilder());
+        assertBalancedActiveAssignment(clientStates, new StringBuilder());
+        assertBalancedStatefulAssignment(allTasks, clientStates, new 
StringBuilder());
+        assertBalancedTasks(clientStates);
+    }
+
+    @Test
     public void shouldAssignStandbysForStatefulTasks() {
         final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1);
         final Set<TaskId> statefulTasks = mkSet(TASK_0_0, TASK_0_1);
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovementTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovementTest.java
index 9b58d18..baf6d18 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovementTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovementTest.java
@@ -19,19 +19,21 @@ package 
org.apache.kafka.streams.processor.internals.assignment;
 import org.apache.kafka.streams.processor.TaskId;
 import org.junit.Test;
 
-import java.util.Collection;
+import java.util.Comparator;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.SortedSet;
 import java.util.TreeMap;
+import java.util.TreeSet;
 import java.util.UUID;
 import java.util.concurrent.atomic.AtomicInteger;
 
 import static java.util.Arrays.asList;
-import static java.util.Collections.emptyList;
+import static java.util.Collections.emptyMap;
+import static java.util.Collections.emptySet;
 import static java.util.Collections.emptySortedSet;
-import static java.util.Collections.singletonList;
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.common.utils.Utils.mkSet;
@@ -58,17 +60,20 @@ public class TaskMovementTest {
         final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, 
TASK_1_0, TASK_1_1, TASK_1_2);
 
         final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = new 
HashMap<>();
+        final Map<TaskId, SortedSet<UUID>> tasksToClientByLag = new 
HashMap<>();
         for (final TaskId task : allTasks) {
             tasksToCaughtUpClients.put(task, mkSortedSet(UUID_1, UUID_2, 
UUID_3));
+            tasksToClientByLag.put(task, mkOrderedSet(UUID_1, UUID_2, UUID_3));
         }
 
-        final ClientState client1 = 
getClientStateWithActiveAssignment(asList(TASK_0_0, TASK_1_0));
-        final ClientState client2 = 
getClientStateWithActiveAssignment(asList(TASK_0_1, TASK_1_1));
-        final ClientState client3 = 
getClientStateWithActiveAssignment(asList(TASK_0_2, TASK_1_2));
+        final ClientState client1 = 
getClientStateWithActiveAssignment(mkSet(TASK_0_0, TASK_1_0), allTasks, 
allTasks);
+        final ClientState client2 = 
getClientStateWithActiveAssignment(mkSet(TASK_0_1, TASK_1_1), allTasks, 
allTasks);
+        final ClientState client3 = 
getClientStateWithActiveAssignment(mkSet(TASK_0_2, TASK_1_2), allTasks, 
allTasks);
 
         assertThat(
             assignActiveTaskMovements(
                 tasksToCaughtUpClients,
+                tasksToClientByLag,
                 getClientStatesMap(client1, client2, client3),
                 new TreeMap<>(),
                 new AtomicInteger(maxWarmupReplicas)
@@ -80,10 +85,11 @@ public class TaskMovementTest {
     @Test
     public void 
shouldAssignAllTasksToClientsAndReturnFalseIfNoClientsAreCaughtUp() {
         final int maxWarmupReplicas = Integer.MAX_VALUE;
+        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, 
TASK_1_0, TASK_1_1, TASK_1_2);
 
-        final ClientState client1 = 
getClientStateWithActiveAssignment(asList(TASK_0_0, TASK_1_0));
-        final ClientState client2 = 
getClientStateWithActiveAssignment(asList(TASK_0_1, TASK_1_1));
-        final ClientState client3 = 
getClientStateWithActiveAssignment(asList(TASK_0_2, TASK_1_2));
+        final ClientState client1 = 
getClientStateWithActiveAssignment(mkSet(TASK_0_0, TASK_1_0), mkSet(), 
allTasks);
+        final ClientState client2 = 
getClientStateWithActiveAssignment(mkSet(TASK_0_1, TASK_1_1), mkSet(), 
allTasks);
+        final ClientState client3 = 
getClientStateWithActiveAssignment(mkSet(TASK_0_2, TASK_1_2), mkSet(), 
allTasks);
 
         final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = mkMap(
             mkEntry(TASK_0_0, emptySortedSet()),
@@ -93,9 +99,18 @@ public class TaskMovementTest {
             mkEntry(TASK_1_1, emptySortedSet()),
             mkEntry(TASK_1_2, emptySortedSet())
         );
+        final Map<TaskId, SortedSet<UUID>> tasksToClientByLag = mkMap(
+            mkEntry(TASK_0_0, mkOrderedSet(UUID_1, UUID_2, UUID_3)),
+            mkEntry(TASK_0_1, mkOrderedSet(UUID_1, UUID_2, UUID_3)),
+            mkEntry(TASK_0_2, mkOrderedSet(UUID_1, UUID_2, UUID_3)),
+            mkEntry(TASK_1_0, mkOrderedSet(UUID_1, UUID_2, UUID_3)),
+            mkEntry(TASK_1_1, mkOrderedSet(UUID_1, UUID_2, UUID_3)),
+            mkEntry(TASK_1_2, mkOrderedSet(UUID_1, UUID_2, UUID_3))
+        );
         assertThat(
             assignActiveTaskMovements(
                 tasksToCaughtUpClients,
+                tasksToClientByLag,
                 getClientStatesMap(client1, client2, client3),
                 new TreeMap<>(),
                 new AtomicInteger(maxWarmupReplicas)
@@ -107,9 +122,10 @@ public class TaskMovementTest {
     @Test
     public void 
shouldMoveTasksToCaughtUpClientsAndAssignWarmupReplicasInTheirPlace() {
         final int maxWarmupReplicas = Integer.MAX_VALUE;
-        final ClientState client1 = 
getClientStateWithActiveAssignment(singletonList(TASK_0_0));
-        final ClientState client2 = 
getClientStateWithActiveAssignment(singletonList(TASK_0_1));
-        final ClientState client3 = 
getClientStateWithActiveAssignment(singletonList(TASK_0_2));
+        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2);
+        final ClientState client1 = 
getClientStateWithActiveAssignment(mkSet(TASK_0_0), mkSet(TASK_0_0), allTasks);
+        final ClientState client2 = 
getClientStateWithActiveAssignment(mkSet(TASK_0_1), mkSet(TASK_0_2), allTasks);
+        final ClientState client3 = 
getClientStateWithActiveAssignment(mkSet(TASK_0_2), mkSet(TASK_0_1), allTasks);
         final Map<UUID, ClientState> clientStates = 
getClientStatesMap(client1, client2, client3);
 
         final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = mkMap(
@@ -117,11 +133,17 @@ public class TaskMovementTest {
             mkEntry(TASK_0_1, mkSortedSet(UUID_3)),
             mkEntry(TASK_0_2, mkSortedSet(UUID_2))
         );
+        final Map<TaskId, SortedSet<UUID>> tasksToClientByLag = mkMap(
+            mkEntry(TASK_0_0, mkOrderedSet(UUID_1, UUID_2, UUID_3)),
+            mkEntry(TASK_0_1, mkOrderedSet(UUID_3, UUID_1, UUID_2)),
+            mkEntry(TASK_0_2, mkOrderedSet(UUID_2, UUID_1, UUID_3))
+        );
 
         assertThat(
             "should have assigned movements",
             assignActiveTaskMovements(
                 tasksToCaughtUpClients,
+                tasksToClientByLag,
                 clientStates,
                 new TreeMap<>(),
                 new AtomicInteger(maxWarmupReplicas)
@@ -140,11 +162,59 @@ public class TaskMovementTest {
     }
 
     @Test
+    public void 
shouldMoveTasksToMostCaughtUpClientsAndAssignWarmupReplicasInTheirPlace() {
+        final int maxWarmupReplicas = Integer.MAX_VALUE;
+        final Map<TaskId, Long> client1Lags = mkMap(mkEntry(TASK_0_0, 10000L), 
mkEntry(TASK_0_1, 20000L), mkEntry(TASK_0_2, 30000L));
+        final Map<TaskId, Long> client2Lags = mkMap(mkEntry(TASK_0_2, 10000L), 
mkEntry(TASK_0_0, 20000L), mkEntry(TASK_0_1, 30000L));
+        final Map<TaskId, Long> client3Lags = mkMap(mkEntry(TASK_0_1, 10000L), 
mkEntry(TASK_0_2, 20000L), mkEntry(TASK_0_0, 30000L));
+
+        final ClientState client1 = getClientStateWithLags(mkSet(TASK_0_0), 
client1Lags);
+        final ClientState client2 = getClientStateWithLags(mkSet(TASK_0_1), 
client2Lags);
+        final ClientState client3 = getClientStateWithLags(mkSet(TASK_0_2), 
client3Lags);
+        // To test when the task is already a standby on the most caught up 
node
+        client3.assignStandby(TASK_0_1);
+        final Map<UUID, ClientState> clientStates = 
getClientStatesMap(client1, client2, client3);
+
+        final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = mkMap(
+                mkEntry(TASK_0_0, mkSortedSet()),
+                mkEntry(TASK_0_1, mkSortedSet()),
+                mkEntry(TASK_0_2, mkSortedSet())
+        );
+        final Map<TaskId, SortedSet<UUID>> tasksToClientByLag = mkMap(
+                mkEntry(TASK_0_0, mkOrderedSet(UUID_1, UUID_2, UUID_3)),
+                mkEntry(TASK_0_1, mkOrderedSet(UUID_3, UUID_1, UUID_2)),
+                mkEntry(TASK_0_2, mkOrderedSet(UUID_2, UUID_3, UUID_1))
+        );
+
+        assertThat(
+                "should have assigned movements",
+                assignActiveTaskMovements(
+                        tasksToCaughtUpClients,
+                        tasksToClientByLag,
+                        clientStates,
+                        new TreeMap<>(),
+                        new AtomicInteger(maxWarmupReplicas)
+                ),
+                is(2)
+        );
+        // The active tasks have changed to the ones that each client is most 
caught up on
+        assertThat(client1, hasProperty("activeTasks", 
ClientState::activeTasks, mkSet(TASK_0_0)));
+        assertThat(client2, hasProperty("activeTasks", 
ClientState::activeTasks, mkSet(TASK_0_2)));
+        assertThat(client3, hasProperty("activeTasks", 
ClientState::activeTasks, mkSet(TASK_0_1)));
+
+        // we assigned warmups to migrate to the input active assignment
+        assertThat(client1, hasProperty("standbyTasks", 
ClientState::standbyTasks, mkSet()));
+        assertThat(client2, hasProperty("standbyTasks", 
ClientState::standbyTasks, mkSet(TASK_0_1)));
+        assertThat(client3, hasProperty("standbyTasks", 
ClientState::standbyTasks, mkSet(TASK_0_2)));
+    }
+
+    @Test
     public void shouldOnlyGetUpToMaxWarmupReplicasAndReturnTrue() {
         final int maxWarmupReplicas = 1;
-        final ClientState client1 = 
getClientStateWithActiveAssignment(singletonList(TASK_0_0));
-        final ClientState client2 = 
getClientStateWithActiveAssignment(singletonList(TASK_0_1));
-        final ClientState client3 = 
getClientStateWithActiveAssignment(singletonList(TASK_0_2));
+        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2);
+        final ClientState client1 = 
getClientStateWithActiveAssignment(mkSet(TASK_0_0), mkSet(TASK_0_0), allTasks);
+        final ClientState client2 = 
getClientStateWithActiveAssignment(mkSet(TASK_0_1), mkSet(TASK_0_2), allTasks);
+        final ClientState client3 = 
getClientStateWithActiveAssignment(mkSet(TASK_0_2), mkSet(TASK_0_1), allTasks);
         final Map<UUID, ClientState> clientStates = 
getClientStatesMap(client1, client2, client3);
 
         final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = mkMap(
@@ -152,11 +222,17 @@ public class TaskMovementTest {
             mkEntry(TASK_0_1, mkSortedSet(UUID_3)),
             mkEntry(TASK_0_2, mkSortedSet(UUID_2))
         );
+        final Map<TaskId, SortedSet<UUID>> tasksToClientByLag = mkMap(
+            mkEntry(TASK_0_0, mkOrderedSet(UUID_1, UUID_2, UUID_3)),
+            mkEntry(TASK_0_1, mkOrderedSet(UUID_3, UUID_1, UUID_2)),
+            mkEntry(TASK_0_2, mkOrderedSet(UUID_2, UUID_1, UUID_3))
+        );
 
         assertThat(
             "should have assigned movements",
             assignActiveTaskMovements(
                 tasksToCaughtUpClients,
+                tasksToClientByLag,
                 clientStates,
                 new TreeMap<>(),
                 new AtomicInteger(maxWarmupReplicas)
@@ -182,19 +258,24 @@ public class TaskMovementTest {
     @Test
     public void shouldNotCountPreviousStandbyTasksTowardsMaxWarmupReplicas() {
         final int maxWarmupReplicas = 0;
-        final ClientState client1 = 
getClientStateWithActiveAssignment(emptyList());
+        final Set<TaskId> allTasks = mkSet(TASK_0_0);
+        final ClientState client1 = 
getClientStateWithActiveAssignment(mkSet(), mkSet(TASK_0_0), allTasks);
         client1.assignStandby(TASK_0_0);
-        final ClientState client2 = 
getClientStateWithActiveAssignment(singletonList(TASK_0_0));
+        final ClientState client2 = 
getClientStateWithActiveAssignment(mkSet(TASK_0_0), mkSet(), allTasks);
         final Map<UUID, ClientState> clientStates = 
getClientStatesMap(client1, client2);
 
         final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = mkMap(
             mkEntry(TASK_0_0, mkSortedSet(UUID_1))
         );
+        final Map<TaskId, SortedSet<UUID>> tasksToClientByLag = mkMap(
+            mkEntry(TASK_0_0, mkOrderedSet(UUID_1, UUID_2))
+        );
 
         assertThat(
             "should have assigned movements",
             assignActiveTaskMovements(
                 tasksToCaughtUpClients,
+                tasksToClientByLag,
                 clientStates,
                 new TreeMap<>(),
                 new AtomicInteger(maxWarmupReplicas)
@@ -215,10 +296,35 @@ public class TaskMovementTest {
 
     }
 
-    private static ClientState getClientStateWithActiveAssignment(final 
Collection<TaskId> activeTasks) {
-        final ClientState client1 = new ClientState(1);
+    private static ClientState getClientStateWithActiveAssignment(final 
Set<TaskId> activeTasks,
+                                                                  final 
Set<TaskId> caughtUpTasks,
+                                                                  final 
Set<TaskId> allTasks) {
+        final Map<TaskId, Long> lags = new HashMap<>();
+        for (final TaskId task : allTasks) {
+            if (caughtUpTasks.contains(task)) {
+                lags.put(task, 0L);
+            } else {
+                lags.put(task, 10000L);
+            }
+        }
+        return getClientStateWithLags(activeTasks, lags);
+    }
+
+    private static ClientState getClientStateWithLags(final Set<TaskId> 
activeTasks,
+                                                      final Map<TaskId, Long> 
taskLags) {
+        final ClientState client1 = new ClientState(activeTasks, emptySet(), 
taskLags, emptyMap(), 1);
         client1.assignActiveTasks(activeTasks);
         return client1;
     }
 
+    /**
+     * Creates a SortedSet with the sort order being the order of elements in 
the parameter list
+     */
+    private static SortedSet<UUID> mkOrderedSet(final UUID... clients) {
+        final List<UUID> clientList = asList(clients);
+        final SortedSet<UUID> set = new 
TreeSet<>(Comparator.comparing(clientList::indexOf));
+        set.addAll(clientList);
+        return set;
+    }
+
 }

Reply via email to