yashmayya commented on code in PR #15617:
URL: https://github.com/apache/pinot/pull/15617#discussion_r2079510950


##########
pinot-controller/src/main/java/org/apache/pinot/controller/api/resources/PinotTableRestletResource.java:
##########
@@ -625,6 +625,13 @@ public RebalanceResult rebalance(
           + "more servers.") @DefaultValue("false") @QueryParam("lowDiskMode") boolean lowDiskMode,
       @ApiParam(value = "Whether to use best-efforts to rebalance (not fail the rebalance when the no-downtime "
           + "contract cannot be achieved)") @DefaultValue("false") @QueryParam("bestEfforts") boolean bestEfforts,
+      @ApiParam(value = "How many maximum segment adds per server to update in the IdealState in each step. For "
+          + "non-strict replica group based assignment, this number will be the closest possible without splitting up "
+          + "a single segment's step's replicas across steps (so some servers may get fewer segments). For strict "
+          + "replica group based assignment, this is a per-server best effort value since each partition of a replica "
+          + "group must be moved as a whole and at least one partition in a replica group should be moved. A value of "
+          + "-1 is used to disable batching (unlimited segments).")

Review Comment:
   > unlimited segments
   
   Might be useful to explicitly clarify that this is "per incremental step in the rebalance", while keeping the min available replicas invariant intact, so that users don't get confused about the batching mechanism's semantics.
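To illustrate the per-step semantics the comment asks to document, here is a minimal sketch of how one incremental step might cap segment adds per server, loosely modeled on the non-strict replica group loop in this PR. It is a simplification under stated assumptions: each pending segment move is reduced to the set of servers it is added to (removals ignored), the class and method names are hypothetical, `-1` stands in for `RebalanceConfig.DISABLE_BATCH_SIZE_PER_SERVER`, and the min-available-replicas check done by `getNextSingleSegmentAssignment` is not modeled.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Hypothetical sketch of ONE rebalance step with a per-server cap on segment adds.
// Assumption: each segment's move is described only by the servers it is added to.
final class BatchStepSketch {
  // Stand-in for RebalanceConfig.DISABLE_BATCH_SIZE_PER_SERVER (assumed -1, per the API description above)
  static final int DISABLE_BATCH_SIZE_PER_SERVER = -1;

  private BatchStepSketch() {
  }

  /**
   * Returns, in iteration order, the segments whose moves fit into this step. A segment is deferred
   * (left on its current servers for a later step) if any server it would be added to has already
   * received batchSizePerServer segments in this step. The counters only live for this one step.
   */
  static List<String> segmentsMovedInStep(LinkedHashMap<String, Set<String>> serversAddedPerSegment,
      int batchSizePerServer) {
    List<String> moved = new ArrayList<>();
    Map<String, Integer> addsSoFarInStep = new HashMap<>();
    for (Map.Entry<String, Set<String>> entry : serversAddedPerSegment.entrySet()) {
      boolean anyServerExhausted = false;
      if (batchSizePerServer != DISABLE_BATCH_SIZE_PER_SERVER) {
        for (String server : entry.getValue()) {
          if (addsSoFarInStep.getOrDefault(server, 0) >= batchSizePerServer) {
            anyServerExhausted = true;
            break;
          }
        }
      }
      if (!anyServerExhausted) {
        moved.add(entry.getKey());
        entry.getValue().forEach(server -> addsSoFarInStep.merge(server, 1, Integer::sum));
      }
    }
    return moved;
  }
}
```

With a cap of 1 and moves `seg1 -> {A, B}`, `seg2 -> {A}`, `seg3 -> {B}`, `seg4 -> {C}`, only `seg1` and `seg4` move in this step; the other two wait for a later step. In this model, "unlimited" (`-1`) removes only the per-step cap and says nothing about relaxing min available replicas.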



##########
pinot-controller/src/main/java/org/apache/pinot/controller/api/resources/PinotTableRestletResource.java:
##########
@@ -625,6 +625,13 @@ public RebalanceResult rebalance(
           + "more servers.") @DefaultValue("false") @QueryParam("lowDiskMode") boolean lowDiskMode,
       @ApiParam(value = "Whether to use best-efforts to rebalance (not fail the rebalance when the no-downtime "
           + "contract cannot be achieved)") @DefaultValue("false") @QueryParam("bestEfforts") boolean bestEfforts,
+      @ApiParam(value = "How many maximum segment adds per server to update in the IdealState in each step. For "
+          + "non-strict replica group based assignment, this number will be the closest possible without splitting up "
+          + "a single segment's step's replicas across steps (so some servers may get fewer segments). For strict "

Review Comment:
   > a single segment's step's replicas across steps
   
   nit: this wording is a little confusing



##########
pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/TableRebalancer.java:
##########
@@ -1524,67 +1535,336 @@ private static void handleErrorInstance(String tableNameWithType, String segment
     }
   }
 
+  /**
+   * Uses the default LOGGER
+   */
+  @VisibleForTesting
+  static Map<String, Map<String, String>> getNextAssignment(Map<String, Map<String, String>> currentAssignment,
+      Map<String, Map<String, String>> targetAssignment, int minAvailableReplicas, boolean enableStrictReplicaGroup,
+      boolean lowDiskMode, int batchSizePerServer, Object2IntOpenHashMap<String> segmentPartitionIdMap,
+      PartitionIdFetcher partitionIdFetcher, boolean isStrictRealtimeSegmentAssignment) {
+    return getNextAssignment(currentAssignment, targetAssignment, minAvailableReplicas, enableStrictReplicaGroup,
+        lowDiskMode, batchSizePerServer, segmentPartitionIdMap, partitionIdFetcher, isStrictRealtimeSegmentAssignment,
+        LOGGER);
+  }
+
   /**
    * Returns the next assignment for the table based on the current assignment and the target assignment with regard to
    * the minimum available replicas requirement. For strict replica-group mode, track the available instances for all
    * the segments with the same instances in the next assignment, and ensure the minimum available replicas requirement
    * is met. If adding the assignment for a segment breaks the requirement, use the current assignment for the segment.
+   *
+   * For strict replica group routing only (where the segment assignment is not StrictRealtimeSegmentAssignment)
+   * if batching is enabled, don't group the assignment by partitionId, since the segments of the same partitionId do
+   * not need to be assigned to the same servers. For strict replica group routing with strict replica group
+   * assignment on the other hand, group the assignment by partitionId since a partition must move as a whole, and they
+   * have the same servers assigned across all segments belonging to the same partitionId.

Review Comment:
   `StrictRealtimeSegmentAssignment` tries to ensure that all segments in a partition (`CONSUMING` and `COMPLETED`) are assigned to a _single_ instance in each replica group, right?
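The question matters because the batching code for strict assignment inspects only the first segment of each partition ("Each partition should be assigned to the same set of servers"). A minimal sketch of that grouping and of a check for the property in question, assuming plain JDK types in place of the PR's `Object2IntOpenHashMap` cache and `PartitionIdFetcher` (the class and method names below are hypothetical):

```java
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.function.ToIntFunction;

// Hypothetical sketch: group a segment assignment by partition id and check whether every partition
// is hosted on a single instance set. Not the PR's implementation.
final class PartitionGroupingSketch {
  private PartitionGroupingSketch() {
  }

  // Same shape as getPartitionIdToCurrentAssignmentMap in the diff: partitionId -> (segment -> instanceStateMap)
  static Map<Integer, Map<String, Map<String, String>>> groupByPartition(
      Map<String, Map<String, String>> currentAssignment, Map<String, Integer> partitionIdCache,
      ToIntFunction<String> partitionIdFetcher) {
    Map<Integer, Map<String, Map<String, String>>> byPartition = new TreeMap<>();
    for (Map.Entry<String, Map<String, String>> entry : currentAssignment.entrySet()) {
      String segmentName = entry.getKey();
      // Fetch the partition id at most once per segment, like the cache in the diff
      int partitionId = partitionIdCache.computeIfAbsent(segmentName, partitionIdFetcher::applyAsInt);
      byPartition.computeIfAbsent(partitionId, k -> new TreeMap<>()).put(segmentName, entry.getValue());
    }
    return byPartition;
  }

  // True if, within every partition, all segments share exactly the same set of instances
  static boolean eachPartitionOnOneInstanceSet(Map<Integer, Map<String, Map<String, String>>> byPartition) {
    for (Map<String, Map<String, String>> segments : byPartition.values()) {
      Set<Set<String>> distinctInstanceSets = new HashSet<>();
      for (Map<String, String> instanceStateMap : segments.values()) {
        distinctInstanceSets.add(new HashSet<>(instanceStateMap.keySet()));
      }
      if (distinctInstanceSets.size() > 1) {
        return false;
      }
    }
    return true;
  }
}
```

If the check can return `false` for a table, inspecting only one segment per partition would under-count the servers that partition touches.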



##########
pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/TableRebalancer.java:
##########
@@ -1524,67 +1535,336 @@ private static void handleErrorInstance(String tableNameWithType, String segment
     }
   }
 
+  /**
+   * Uses the default LOGGER
+   */
+  @VisibleForTesting
+  static Map<String, Map<String, String>> getNextAssignment(Map<String, Map<String, String>> currentAssignment,
+      Map<String, Map<String, String>> targetAssignment, int minAvailableReplicas, boolean enableStrictReplicaGroup,
+      boolean lowDiskMode, int batchSizePerServer, Object2IntOpenHashMap<String> segmentPartitionIdMap,
+      PartitionIdFetcher partitionIdFetcher, boolean isStrictRealtimeSegmentAssignment) {
+    return getNextAssignment(currentAssignment, targetAssignment, minAvailableReplicas, enableStrictReplicaGroup,
+        lowDiskMode, batchSizePerServer, segmentPartitionIdMap, partitionIdFetcher, isStrictRealtimeSegmentAssignment,
+        LOGGER);
+  }
+
   /**
    * Returns the next assignment for the table based on the current assignment and the target assignment with regard to
    * the minimum available replicas requirement. For strict replica-group mode, track the available instances for all
    * the segments with the same instances in the next assignment, and ensure the minimum available replicas requirement
    * is met. If adding the assignment for a segment breaks the requirement, use the current assignment for the segment.
+   *
+   * For strict replica group routing only (where the segment assignment is not StrictRealtimeSegmentAssignment)
+   * if batching is enabled, don't group the assignment by partitionId, since the segments of the same partitionId do
+   * not need to be assigned to the same servers. For strict replica group routing with strict replica group
+   * assignment on the other hand, group the assignment by partitionId since a partition must move as a whole, and they
+   * have the same servers assigned across all segments belonging to the same partitionId.
+   *
+   * TODO: Ideally if strict replica group routing is enabled then StrictRealtimeSegmentAssignment should be used, but

Review Comment:
   ```suggestion
      * <p>
      * For strict replica group routing only (where the segment assignment is not StrictRealtimeSegmentAssignment)
      * if batching is enabled, don't group the assignment by partitionId, since the segments of the same partitionId do
      * not need to be assigned to the same servers. For strict replica group routing with strict replica group
      * assignment on the other hand, group the assignment by partitionId since a partition must move as a whole, and they
      * have the same servers assigned across all segments belonging to the same partitionId.
      * <p>
      * TODO: Ideally if strict replica group routing is enabled then StrictRealtimeSegmentAssignment should be used, but
   ```
   nit: it's difficult to read when rendered currently



##########
pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/TableRebalancer.java:
##########
@@ -1595,13 +1875,41 @@ private static Map<String, Map<String, String>> getNextNonStrictReplicaGroupAssi
       Map<String, String> nextInstanceStateMap =
           getNextSingleSegmentAssignment(currentInstanceStateMap, targetInstanceStateMap, minAvailableReplicas,
               lowDiskMode, numSegmentsToOffloadMap, assignmentMap)._instanceStateMap;
-      nextAssignment.put(segmentName, nextInstanceStateMap);
-      updateNumSegmentsToOffloadMap(numSegmentsToOffloadMap, currentInstanceStateMap.keySet(),
-          nextInstanceStateMap.keySet());
+      Set<String> serversAddedForSegment = getServersAddedInSingleSegmentAssignment(currentInstanceStateMap,
+          nextInstanceStateMap);
+      boolean anyServerExhaustedBatchSize = false;
+      if (batchSizePerServer != RebalanceConfig.DISABLE_BATCH_SIZE_PER_SERVER) {
+        for (String server : serversAddedForSegment) {
+          if (serverToNumSegmentsAddedSoFar.getOrDefault(server, 0) >= batchSizePerServer) {
+            anyServerExhaustedBatchSize = true;
+            break;
+          }
+        }
+      }
+      if (anyServerExhaustedBatchSize) {
+        // Exhausted the batch size for at least 1 server, set to existing assignment
+        nextAssignment.put(segmentName, currentInstanceStateMap);
+      } else {
+        // Add the next assignment and update the segments added so far counts
+        for (String server : serversAddedForSegment) {
+          int numSegmentsAdded = serverToNumSegmentsAddedSoFar.getOrDefault(server, 0);
+          serverToNumSegmentsAddedSoFar.put(server, numSegmentsAdded + 1);
+        }

Review Comment:
   ```suggestion
           serversAddedForSegment.forEach(server -> serverToNumSegmentsAddedSoFar.merge(server, 1, Integer::sum));
   ```
   nit: optional suggestion
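For reference, both counting styles in this thread, written as hypothetical standalone helpers (the names are not from the PR), produce the same counts on a `HashMap<String, Integer>`:

```java
import java.util.Map;
import java.util.Set;

// Hypothetical helpers comparing the two counting styles discussed in this review comment.
final class AddCountSketch {
  private AddCountSketch() {
  }

  // Style in the PR: read the current count (default 0), then write back count + 1
  static void incrementWithGetOrDefault(Map<String, Integer> counts, Set<String> servers) {
    for (String server : servers) {
      int numSegmentsAdded = counts.getOrDefault(server, 0);
      counts.put(server, numSegmentsAdded + 1);
    }
  }

  // Suggested style: merge stores 1 for an absent key, otherwise combines old and new values with Integer::sum
  static void incrementWithMerge(Map<String, Integer> counts, Set<String> servers) {
    servers.forEach(server -> counts.merge(server, 1, Integer::sum));
  }
}
```

The two are interchangeable here because `Integer::sum` never returns `null`; `Map.merge` removes the key only when the remapping function returns `null`.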



##########
pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/TableRebalancer.java:
##########
@@ -1595,13 +1875,41 @@ private static Map<String, Map<String, String>> getNextNonStrictReplicaGroupAssi
       Map<String, String> nextInstanceStateMap =
           getNextSingleSegmentAssignment(currentInstanceStateMap, targetInstanceStateMap, minAvailableReplicas,
               lowDiskMode, numSegmentsToOffloadMap, assignmentMap)._instanceStateMap;
-      nextAssignment.put(segmentName, nextInstanceStateMap);
-      updateNumSegmentsToOffloadMap(numSegmentsToOffloadMap, currentInstanceStateMap.keySet(),
-          nextInstanceStateMap.keySet());
+      Set<String> serversAddedForSegment = getServersAddedInSingleSegmentAssignment(currentInstanceStateMap,

Review Comment:
   So we're defining the batch (in batch size) as the number of segments being added across all instances without counting any segment removals, correct?
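The body of `getServersAddedInSingleSegmentAssignment` is not part of this excerpt. Assuming it returns the servers present in the next instance state map but absent from the current one, the sketch below (hypothetical names) shows why removals never count toward a server's batch: a server that only drops a segment never appears in the difference.

```java
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

// Hypothetical sketch; assumes "servers added" means next servers minus current servers.
final class ServersAddedSketch {
  private ServersAddedSketch() {
  }

  static Set<String> serversAdded(Map<String, String> currentInstanceStateMap,
      Map<String, String> nextInstanceStateMap) {
    Set<String> added = new HashSet<>(nextInstanceStateMap.keySet());
    // Servers that already host the segment, or that drop it, contribute nothing
    added.removeAll(currentInstanceStateMap.keySet());
    return added;
  }

  // Per-server count of segment adds for one step; segments removed from a server are not counted
  static Map<String, Integer> addsPerServer(Map<String, Map<String, String>> currentAssignment,
      Map<String, Map<String, String>> nextAssignment) {
    Map<String, Integer> counts = new HashMap<>();
    for (Map.Entry<String, Map<String, String>> entry : nextAssignment.entrySet()) {
      Map<String, String> current = currentAssignment.getOrDefault(entry.getKey(), Map.of());
      serversAdded(current, entry.getValue()).forEach(server -> counts.merge(server, 1, Integer::sum));
    }
    return counts;
  }
}
```

In this model, moving `seg1` from `{A, B}` to `{B, C}` counts one add for `C` and nothing for `A`, so the batch bounds only the segments each server gains in a step.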



##########
pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/TableRebalancer.java:
##########
@@ -1524,67 +1535,336 @@ private static void handleErrorInstance(String tableNameWithType, String segment
     }
   }
 
+  /**
+   * Uses the default LOGGER
+   */
+  @VisibleForTesting
+  static Map<String, Map<String, String>> getNextAssignment(Map<String, Map<String, String>> currentAssignment,
+      Map<String, Map<String, String>> targetAssignment, int minAvailableReplicas, boolean enableStrictReplicaGroup,
+      boolean lowDiskMode, int batchSizePerServer, Object2IntOpenHashMap<String> segmentPartitionIdMap,
+      PartitionIdFetcher partitionIdFetcher, boolean isStrictRealtimeSegmentAssignment) {
+    return getNextAssignment(currentAssignment, targetAssignment, minAvailableReplicas, enableStrictReplicaGroup,
+        lowDiskMode, batchSizePerServer, segmentPartitionIdMap, partitionIdFetcher, isStrictRealtimeSegmentAssignment,
+        LOGGER);
+  }
+
   /**
    * Returns the next assignment for the table based on the current assignment and the target assignment with regard to
    * the minimum available replicas requirement. For strict replica-group mode, track the available instances for all
    * the segments with the same instances in the next assignment, and ensure the minimum available replicas requirement
    * is met. If adding the assignment for a segment breaks the requirement, use the current assignment for the segment.
+   *
+   * For strict replica group routing only (where the segment assignment is not StrictRealtimeSegmentAssignment)
+   * if batching is enabled, don't group the assignment by partitionId, since the segments of the same partitionId do
+   * not need to be assigned to the same servers. For strict replica group routing with strict replica group
+   * assignment on the other hand, group the assignment by partitionId since a partition must move as a whole, and they
+   * have the same servers assigned across all segments belonging to the same partitionId.
+   *
+   * TODO: Ideally if strict replica group routing is enabled then StrictRealtimeSegmentAssignment should be used, but
+   *       this is not enforced in the code today. Once enforcement is added, there will no longer be any need to
+   *       handle strict replica group routing only v.s. strict replica group routing + assignment. Remove the
+   *       getNextStrictReplicaGroupRoutingOnlyAssignment() function.
    */
-  @VisibleForTesting
-  static Map<String, Map<String, String>> getNextAssignment(Map<String, Map<String, String>> currentAssignment,
+  private static Map<String, Map<String, String>> getNextAssignment(Map<String, Map<String, String>> currentAssignment,
       Map<String, Map<String, String>> targetAssignment, int minAvailableReplicas, boolean enableStrictReplicaGroup,
-      boolean lowDiskMode) {
-    return enableStrictReplicaGroup ? getNextStrictReplicaGroupAssignment(currentAssignment, targetAssignment,
-        minAvailableReplicas, lowDiskMode)
-        : getNextNonStrictReplicaGroupAssignment(currentAssignment, targetAssignment, minAvailableReplicas,
-            lowDiskMode);
+      boolean lowDiskMode, int batchSizePerServer, Object2IntOpenHashMap<String> segmentPartitionIdMap,
+      PartitionIdFetcher partitionIdFetcher, boolean isStrictRealtimeSegmentAssignment, Logger tableRebalanceLogger) {
+    return (enableStrictReplicaGroup && isStrictRealtimeSegmentAssignment)
+        ? getNextStrictReplicaGroupAssignment(currentAssignment, targetAssignment, minAvailableReplicas, lowDiskMode,
+        batchSizePerServer, segmentPartitionIdMap, partitionIdFetcher, tableRebalanceLogger)
+        : enableStrictReplicaGroup
+            ? getNextStrictReplicaGroupRoutingOnlyAssignment(currentAssignment, targetAssignment, minAvailableReplicas,
+            lowDiskMode, batchSizePerServer, segmentPartitionIdMap, partitionIdFetcher, tableRebalanceLogger)
+            : getNextNonStrictReplicaGroupAssignment(currentAssignment, targetAssignment, minAvailableReplicas,
+            lowDiskMode, batchSizePerServer);
   }
 
   private static Map<String, Map<String, String>> getNextStrictReplicaGroupAssignment(
       Map<String, Map<String, String>> currentAssignment, Map<String, Map<String, String>> targetAssignment,
-      int minAvailableReplicas, boolean lowDiskMode) {
+      int minAvailableReplicas, boolean lowDiskMode, int batchSizePerServer,
+      Object2IntOpenHashMap<String> segmentPartitionIdMap, PartitionIdFetcher partitionIdFetcher,
+      Logger tableRebalanceLogger) {
     Map<String, Map<String, String>> nextAssignment = new TreeMap<>();
     Map<String, Integer> numSegmentsToOffloadMap = getNumSegmentsToOffloadMap(currentAssignment, targetAssignment);
+    Map<Integer, Map<String, Map<String, String>>> partitionIdToCurrentAssignmentMap;
+    if (batchSizePerServer == RebalanceConfig.DISABLE_BATCH_SIZE_PER_SERVER) {
+      // Don't calculate the partition id to current assignment mapping if batching is disabled since
+      // we want to update the next assignment based on all partitions in this case
+      partitionIdToCurrentAssignmentMap = new TreeMap<>();
+      partitionIdToCurrentAssignmentMap.put(0, currentAssignment);
+    } else {
+      partitionIdToCurrentAssignmentMap =
+          getPartitionIdToCurrentAssignmentMap(currentAssignment, segmentPartitionIdMap, partitionIdFetcher);
+    }
     Map<Pair<Set<String>, Set<String>>, Set<String>> assignmentMap = new HashMap<>();
     Map<Set<String>, Set<String>> availableInstancesMap = new HashMap<>();
-    for (Map.Entry<String, Map<String, String>> entry : currentAssignment.entrySet()) {
-      String segmentName = entry.getKey();
-      Map<String, String> currentInstanceStateMap = entry.getValue();
-      Map<String, String> targetInstanceStateMap = targetAssignment.get(segmentName);
-      SingleSegmentAssignment assignment =
-          getNextSingleSegmentAssignment(currentInstanceStateMap, targetInstanceStateMap, minAvailableReplicas,
-              lowDiskMode, numSegmentsToOffloadMap, assignmentMap);
-      Set<String> assignedInstances = assignment._instanceStateMap.keySet();
-      Set<String> availableInstances = assignment._availableInstances;
-      availableInstancesMap.compute(assignedInstances, (k, currentAvailableInstances) -> {
-        if (currentAvailableInstances == null) {
-          // First segment assigned to these instances, use the new assignment and update the available instances
-          nextAssignment.put(segmentName, assignment._instanceStateMap);
-          updateNumSegmentsToOffloadMap(numSegmentsToOffloadMap, currentInstanceStateMap.keySet(), k);
-          return availableInstances;
-        } else {
-          // There are other segments assigned to the same instances, check the available instances to see if adding the
-          // new assignment can still hold the minimum available replicas requirement
-          availableInstances.retainAll(currentAvailableInstances);
-          if (availableInstances.size() >= minAvailableReplicas) {
-            // New assignment can be added
+
+    Map<String, Integer> serverToNumSegmentsAddedSoFar = new HashMap<>();
+    for (Map<String, Map<String, String>> curAssignment : partitionIdToCurrentAssignmentMap.values()) {
+      boolean anyServerExhaustedBatchSize = false;
+      if (batchSizePerServer != RebalanceConfig.DISABLE_BATCH_SIZE_PER_SERVER) {
+        Map.Entry<String, Map<String, String>> firstEntry = curAssignment.entrySet().iterator().next();
+        // Each partition should be assigned to the same set of servers so it is enough to check for whether any server
+        // for one segment is above the limit or not
+        Map<String, String> firstEntryInstanceStateMap = firstEntry.getValue();
+        SingleSegmentAssignment firstAssignment =
+            getNextSingleSegmentAssignment(firstEntryInstanceStateMap, targetAssignment.get(firstEntry.getKey()),
+                minAvailableReplicas, lowDiskMode, numSegmentsToOffloadMap, assignmentMap);
+        Set<String> serversAdded = getServersAddedInSingleSegmentAssignment(firstEntryInstanceStateMap,
+            firstAssignment._instanceStateMap);
+        for (String server : serversAdded) {
+          if (serverToNumSegmentsAddedSoFar.getOrDefault(server, 0) >= batchSizePerServer) {
+            anyServerExhaustedBatchSize = true;
+            break;
+          }
+        }
+      }
+      getNextAssignmentForPartitionIdStrictReplicaGroup(curAssignment, targetAssignment, nextAssignment,
+          anyServerExhaustedBatchSize, minAvailableReplicas, lowDiskMode, numSegmentsToOffloadMap, assignmentMap,
+          availableInstancesMap, serverToNumSegmentsAddedSoFar);
+    }
+
+    checkIfAnyServersAssignedMoreSegmentsThanBatchSize(batchSizePerServer, serverToNumSegmentsAddedSoFar,
+        tableRebalanceLogger);
+    return nextAssignment;
+  }
+
+  /**
+   * Create a mapping of partitionId to the current assignment of segments that belong to that partitionId. This is to
+   * be used for batching purposes for StrictReplicaGroup
+   * @param currentAssignment the current assignment
+   * @param segmentPartitionIdMap cache to store the partition ids to avoid fetching ZK segment metadata
+   * @param partitionIdFetcher function to fetch the partition id
+   * @return a mapping from partitionId to the segment assignment map of all segments that map to that partitionId
+   */
+  private static Map<Integer, Map<String, Map<String, String>>> getPartitionIdToCurrentAssignmentMap(
+      Map<String, Map<String, String>> currentAssignment, Object2IntOpenHashMap<String> segmentPartitionIdMap,
+      PartitionIdFetcher partitionIdFetcher) {
+    Map<Integer, Map<String, Map<String, String>>> partitionIdToCurrentAssignmentMap = new TreeMap<>();
+
+    for (Map.Entry<String, Map<String, String>> assignment : currentAssignment.entrySet()) {
+      String segmentName = assignment.getKey();
+      Map<String, String> instanceStateMap = assignment.getValue();
+
+      int partitionId =
+          segmentPartitionIdMap.computeIfAbsent(segmentName, v -> partitionIdFetcher.fetch(segmentName));
+      partitionIdToCurrentAssignmentMap.computeIfAbsent(partitionId,
+          k -> new TreeMap<>()).put(segmentName, instanceStateMap);
+    }
+
+    return partitionIdToCurrentAssignmentMap;
+  }
+
+  private static Map<String, Map<String, String>> getNextStrictReplicaGroupRoutingOnlyAssignment(
+      Map<String, Map<String, String>> currentAssignment, Map<String, Map<String, String>> targetAssignment,
+      int minAvailableReplicas, boolean lowDiskMode, int batchSizePerServer,
+      Object2IntOpenHashMap<String> segmentPartitionIdMap, PartitionIdFetcher partitionIdFetcher,
+      Logger tableRebalanceLogger) {
+    Map<String, Map<String, String>> nextAssignment = new TreeMap<>();
+    Map<String, Integer> numSegmentsToOffloadMap = getNumSegmentsToOffloadMap(currentAssignment, targetAssignment);
+    Map<Integer, Map<Set<String>, Map<String, Map<String, String>>>>
+        partitionIdToAssignedInstancesToCurrentAssignmentMap;
+    if (batchSizePerServer == RebalanceConfig.DISABLE_BATCH_SIZE_PER_SERVER) {
+      // Don't calculate the partition id to assigned instances to current assignment mapping if batching is disabled
+      // since we want to update the next assignment based on all partitions in this case. Use partitionId as 0
+      // and a dummy set for the assigned instances.
+      partitionIdToAssignedInstancesToCurrentAssignmentMap = new TreeMap<>();
+      partitionIdToAssignedInstancesToCurrentAssignmentMap.put(0, new HashMap<>());
+      partitionIdToAssignedInstancesToCurrentAssignmentMap.get(0).put(Set.of(""), currentAssignment);
+    } else {
+      partitionIdToAssignedInstancesToCurrentAssignmentMap =
+          getPartitionIdToAssignedInstancesToCurrentAssignmentMap(currentAssignment, segmentPartitionIdMap,
+              partitionIdFetcher);
+    }
+    Map<Pair<Set<String>, Set<String>>, Set<String>> assignmentMap = new HashMap<>();
+    Map<Set<String>, Set<String>> availableInstancesMap = new HashMap<>();
+
+    Map<String, Integer> serverToNumSegmentsAddedSoFar = new HashMap<>();
+    for (Map<Set<String>, Map<String, Map<String, String>>> assignedInstancesToCurrentAssignment
+        : partitionIdToAssignedInstancesToCurrentAssignmentMap.values()) {
+      boolean anyServerExhaustedBatchSize = false;
+      if (batchSizePerServer != RebalanceConfig.DISABLE_BATCH_SIZE_PER_SERVER) {
+        // Check if the servers of the first assignment for each unique set of assigned instances has any space left
+        // to move this partition. If so, let's mark the partitions as to be moved, otherwise we mark the partition
+        // as a whole as not moveable.
+        for (Map<String, Map<String, String>> curAssignment : assignedInstancesToCurrentAssignment.values()) {
+          Map.Entry<String, Map<String, String>> firstEntry = curAssignment.entrySet().iterator().next();
+          // All segments should be assigned to the same set of servers so it is enough to check for whether any server
+          // for one segment is above the limit or not
+          Map<String, String> firstEntryInstanceStateMap = firstEntry.getValue();
+          SingleSegmentAssignment firstAssignment =
+              getNextSingleSegmentAssignment(firstEntryInstanceStateMap, targetAssignment.get(firstEntry.getKey()),
+                  minAvailableReplicas, lowDiskMode, numSegmentsToOffloadMap, assignmentMap);
+          Set<String> serversAdded = getServersAddedInSingleSegmentAssignment(firstEntryInstanceStateMap,
+              firstAssignment._instanceStateMap);
+          for (String server : serversAdded) {
+            if (serverToNumSegmentsAddedSoFar.getOrDefault(server, 0) >= batchSizePerServer) {
+              anyServerExhaustedBatchSize = true;
+              break;
+            }
+          }
+          if (anyServerExhaustedBatchSize) {
+            break;
+          }
+        }
+      }
+      for (Map<String, Map<String, String>> curAssignment : assignedInstancesToCurrentAssignment.values()) {
+        getNextAssignmentForPartitionIdStrictReplicaGroup(curAssignment, targetAssignment, nextAssignment,
+            anyServerExhaustedBatchSize, minAvailableReplicas, lowDiskMode, numSegmentsToOffloadMap, assignmentMap,
+            availableInstancesMap, serverToNumSegmentsAddedSoFar);
+      }
+    }
+
+    checkIfAnyServersAssignedMoreSegmentsThanBatchSize(batchSizePerServer, serverToNumSegmentsAddedSoFar,
+        tableRebalanceLogger);
+    return nextAssignment;
+  }
+
+  private static void getNextAssignmentForPartitionIdStrictReplicaGroup(Map<String, Map<String, String>> curAssignment,
+      Map<String, Map<String, String>> targetAssignment, Map<String, Map<String, String>> nextAssignment,
+      boolean anyServerExhaustedBatchSize, int minAvailableReplicas, boolean lowDiskMode,
+      Map<String, Integer> numSegmentsToOffloadMap, Map<Pair<Set<String>, Set<String>>, Set<String>> assignmentMap,
+      Map<Set<String>, Set<String>> availableInstancesMap, Map<String, Integer> serverToNumSegmentsAddedSoFar) {
+    if (anyServerExhaustedBatchSize) {
+      // Exhausted the batch size for at least 1 server, just copy over the remaining segments as is
+      for (Map.Entry<String, Map<String, String>> entry : curAssignment.entrySet()) {
+        String segmentName = entry.getKey();
+        Map<String, String> currentInstanceStateMap = entry.getValue();
+        nextAssignment.put(segmentName, currentInstanceStateMap);
+      }
+    } else {
+      // Process all the partitionIds even if segmentsAddedSoFar becomes larger than batchSizePerServer
+      // Can only do bestEfforts w.r.t. StrictReplicaGroup since a whole partition must be moved together for
+      // maintaining consistency
+      for (Map.Entry<String, Map<String, String>> entry : curAssignment.entrySet()) {
+        String segmentName = entry.getKey();
+        Map<String, String> currentInstanceStateMap = entry.getValue();
+        Map<String, String> targetInstanceStateMap = targetAssignment.get(segmentName);
+        SingleSegmentAssignment assignment =
+            getNextSingleSegmentAssignment(currentInstanceStateMap, targetInstanceStateMap, minAvailableReplicas,
+                lowDiskMode, numSegmentsToOffloadMap, assignmentMap);
+        Set<String> assignedInstances = assignment._instanceStateMap.keySet();
+        Set<String> availableInstances = assignment._availableInstances;
+        availableInstancesMap.compute(assignedInstances, (k, currentAvailableInstances) -> {
+          if (currentAvailableInstances == null) {
+            // First segment assigned to these instances, use the new assignment and update the available instances
             nextAssignment.put(segmentName, assignment._instanceStateMap);
             updateNumSegmentsToOffloadMap(numSegmentsToOffloadMap, currentInstanceStateMap.keySet(), k);
             return availableInstances;
           } else {
-            // New assignment cannot be added, use the current instance state map
-            nextAssignment.put(segmentName, currentInstanceStateMap);
-            return currentAvailableInstances;
+            // There are other segments assigned to the same instances, check the available instances to see if
+            // adding the new assignment can still hold the minimum available replicas requirement
+            availableInstances.retainAll(currentAvailableInstances);
+            if (availableInstances.size() >= minAvailableReplicas) {
+              // New assignment can be added
+              nextAssignment.put(segmentName, assignment._instanceStateMap);
+              updateNumSegmentsToOffloadMap(numSegmentsToOffloadMap, currentInstanceStateMap.keySet(), k);
+              return availableInstances;
+            } else {
+              // New assignment cannot be added, use the current instance state map
+              nextAssignment.put(segmentName, currentInstanceStateMap);
+              return currentAvailableInstances;
+            }
+          }
+        });
+
+        if (!nextAssignment.get(segmentName).equals(currentInstanceStateMap)) {
+          Set<String> serversAddedForSegment = getServersAddedInSingleSegmentAssignment(currentInstanceStateMap,
+              nextAssignment.get(segmentName));
+          for (String server : serversAddedForSegment) {
+            int numSegmentsAdded = serverToNumSegmentsAddedSoFar.getOrDefault(server, 0);
+            serverToNumSegmentsAddedSoFar.put(server, numSegmentsAdded + 1);
           }
         }
-      });
+      }
+    }
+  }
+
+  private static void checkIfAnyServersAssignedMoreSegmentsThanBatchSize(int batchSizePerServer,
+      Map<String, Integer> serverToNumSegmentsAddedSoFar, Logger tableRebalanceLogger) {
+    int maxSegmentsAddedToAnyServer = serverToNumSegmentsAddedSoFar.isEmpty() ? 0
+        : Collections.max(serverToNumSegmentsAddedSoFar.values());
+    if (batchSizePerServer != RebalanceConfig.DISABLE_BATCH_SIZE_PER_SERVER
+        && maxSegmentsAddedToAnyServer > batchSizePerServer) {
+      tableRebalanceLogger.warn("Found at least one server with {} segments added which is larger than "
+          + "batchSizePerServer: {}", maxSegmentsAddedToAnyServer, batchSizePerServer);
+    }
+  }
+
+  /**
+   * Create a mapping of partitionId to the mapping of assigned instances to the current assignment of segments that
+   * belong to that partitionId and assigned instances. This is to be used for batching purposes for StrictReplicaGroup
+   * routing only with non-StrictRealtimeSegmentAssignment
+   * @param currentAssignment the current assignment
+   * @param segmentPartitionIdMap cache to store the partition ids to avoid fetching ZK segment metadata
+   * @param partitionIdFetcher function to fetch the partition id
+   * @return a mapping from partitionId to the assigned instances to the segment assignment map of all segments that
+   *         map to that partitionId and assigned instances
+   */
+  private static Map<Integer, Map<Set<String>, Map<String, Map<String, String>>>>
+  getPartitionIdToAssignedInstancesToCurrentAssignmentMap(Map<String, Map<String, String>> currentAssignment,
+      Object2IntOpenHashMap<String> segmentPartitionIdMap, PartitionIdFetcher partitionIdFetcher) {
+    Map<Integer, Map<Set<String>, Map<String, Map<String, String>>>>
+        partitionIdToAssignedInstancesToCurrentAssignmentMap = new TreeMap<>();
+
+    for (Map.Entry<String, Map<String, String>> assignment : currentAssignment.entrySet()) {
+      String segmentName = assignment.getKey();
+      Map<String, String> instanceStateMap = assignment.getValue();
+
+      int partitionId =
+          segmentPartitionIdMap.computeIfAbsent(segmentName, v -> partitionIdFetcher.fetch(segmentName));
+      Set<String> assignedInstances = instanceStateMap.keySet();
+      partitionIdToAssignedInstancesToCurrentAssignmentMap.putIfAbsent(partitionId, new HashMap<>());
+      partitionIdToAssignedInstancesToCurrentAssignmentMap.get(partitionId)
+          .computeIfAbsent(assignedInstances, k -> new TreeMap<>()).put(segmentName, instanceStateMap);
+    }
+
+    return partitionIdToAssignedInstancesToCurrentAssignmentMap;
+  }
+
+  @VisibleForTesting
+  @FunctionalInterface
+  interface PartitionIdFetcher {
+    int fetch(String segmentName);
+  }
+
+  private static class PartitionIdFetcherImpl implements PartitionIdFetcher {
+    private final String _tableNameWithType;
+    private final String _partitionColumn;
+    private final HelixManager _helixManager;
+    private final boolean _isStrictRealtimeSegmentAssignment;
+
+    private PartitionIdFetcherImpl(String tableNameWithType, @Nullable String partitionColumn,
+        HelixManager helixManager, boolean isStrictRealtimeSegmentAssignment) {
+      _tableNameWithType = tableNameWithType;
+      _partitionColumn = partitionColumn;
+      _helixManager = helixManager;
+      _isStrictRealtimeSegmentAssignment = isStrictRealtimeSegmentAssignment;
+    }
+
+    @Override
+    public int fetch(String segmentName) {
+      Integer partitionId;
+      if (_isStrictRealtimeSegmentAssignment) {
+        // This is how partitionId is calculated for StrictRealtimeSegmentAssignment
+        partitionId =
+            SegmentUtils.getRealtimeSegmentPartitionId(segmentName, _tableNameWithType, _helixManager,
+                _partitionColumn);
+        Preconditions.checkState(partitionId != null, "Failed to find partition id for segment: %s of table: %s",
+            segmentName, _tableNameWithType);
+      } else {
+        // This is how partitionId is calculated for RealtimeSegmentAssignment
+        partitionId = SegmentAssignmentUtils.getRealtimeSegmentPartitionId(segmentName, _tableNameWithType,
+            _helixManager, _partitionColumn);
+      }
+      return partitionId;

Review Comment:
   This could technically be called from
`getNextStrictReplicaGroupRoutingOnlyAssignment` where an `OFFLINE` table has
strict replica group routing configured, right? It _looks_ like this logic might
still work, even though all the util function names indicate they're for
`REALTIME` tables only, because if we're unable to retrieve the partition ID
from the segment name, we try to get it from the segment metadata (which should
work for `OFFLINE` tables as well)?
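
To make the fallback I'm describing concrete, here's a minimal, self-contained sketch (not the PR's code). It assumes the LLC segment name layout `{tableName}__{partitionId}__{sequenceNumber}__{creationTime}` and models the ZK segment metadata read as a plain lookup function; `partitionIdFromSegmentName` and `fetchPartitionId` are hypothetical helpers, not the actual `SegmentUtils` / `SegmentAssignmentUtils` APIs.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.ToIntFunction;

final class PartitionIdLookupSketch {

  // Hypothetical helper: parses the partition id from a name that follows the (assumed) LLC layout
  // "{tableName}__{partitionId}__{sequenceNumber}__{creationTime}". Returns null for any other name,
  // e.g. a typical OFFLINE segment name.
  static Integer partitionIdFromSegmentName(String segmentName) {
    String[] parts = segmentName.split("__");
    if (parts.length != 4) {
      return null;
    }
    try {
      return Integer.parseInt(parts[1]);
    } catch (NumberFormatException e) {
      return null;
    }
  }

  // Hypothetical helper: try the segment name first, then fall back to a metadata lookup. In Pinot the
  // fallback would read ZK segment metadata (which OFFLINE segments also have); here it is just a function.
  static int fetchPartitionId(String segmentName, ToIntFunction<String> metadataLookup) {
    Integer fromName = partitionIdFromSegmentName(segmentName);
    return fromName != null ? fromName : metadataLookup.applyAsInt(segmentName);
  }

  public static void main(String[] args) {
    // Made-up stand-in for the ZK segment metadata of one OFFLINE segment.
    Map<String, Integer> fakeMetadata = Map.of("myTable_OFFLINE_0", 3);
    ToIntFunction<String> metadataLookup = segment -> {
      Integer id = fakeMetadata.get(segment);
      if (id == null) {
        throw new IllegalStateException("Failed to find partition id for segment: " + segment);
      }
      return id;
    };

    // Cache so each segment's partition id is resolved at most once, similar in spirit to segmentPartitionIdMap.
    Map<String, Integer> cache = new HashMap<>();
    for (String segment : new String[]{"myTable__1__7__20250101T0000Z", "myTable_OFFLINE_0"}) {
      int partitionId = cache.computeIfAbsent(segment, s -> fetchPartitionId(s, metadataLookup));
      // prints "myTable__1__7__20250101T0000Z -> 1", then "myTable_OFFLINE_0 -> 3"
      System.out.println(segment + " -> " + partitionId);
    }
  }
}
```

If the real code path behaves like this, an OFFLINE segment name simply fails the name parse and falls through to the metadata read, which is why I think it still works despite the `Realtime` naming.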



##########
pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/TableRebalancer.java:
##########
@@ -1524,67 +1535,336 @@ private static void handleErrorInstance(String tableNameWithType, String segment
     }
   }
 
+  /**
+   * Uses the default LOGGER
+   */
+  @VisibleForTesting
+  static Map<String, Map<String, String>> getNextAssignment(Map<String, Map<String, String>> currentAssignment,
+      Map<String, Map<String, String>> targetAssignment, int minAvailableReplicas, boolean enableStrictReplicaGroup,
+      boolean lowDiskMode, int batchSizePerServer, Object2IntOpenHashMap<String> segmentPartitionIdMap,
+      PartitionIdFetcher partitionIdFetcher, boolean isStrictRealtimeSegmentAssignment) {
+    return getNextAssignment(currentAssignment, targetAssignment, minAvailableReplicas, enableStrictReplicaGroup,
+        lowDiskMode, batchSizePerServer, segmentPartitionIdMap, partitionIdFetcher, isStrictRealtimeSegmentAssignment,
+        LOGGER);
+  }
+
   /**
    * Returns the next assignment for the table based on the current assignment and the target assignment with regard to
    * the minimum available replicas requirement. For strict replica-group mode, track the available instances for all
    * the segments with the same instances in the next assignment, and ensure the minimum available replicas requirement
    * is met. If adding the assignment for a segment breaks the requirement, use the current assignment for the segment.
+   *
+   * For strict replica group routing only (where the segment assignment is not StrictRealtimeSegmentAssignment),
+   * if batching is enabled, don't group the assignment by partitionId, since the segments of the same partitionId do
+   * not need to be assigned to the same servers. For strict replica group routing with strict replica group
+   * assignment, on the other hand, group the assignment by partitionId, since a partition must move as a whole and
+   * all segments belonging to the same partitionId have the same servers assigned.
+   *
+   * TODO: Ideally if strict replica group routing is enabled then StrictRealtimeSegmentAssignment should be used, but

Review Comment:
   Just for my understanding - today we allow strict replica group routing to 
be enabled for OFFLINE tables and non-upsert REALTIME tables, but there's no 
real benefit in doing so and we plan to remove this unintended functionality in 
the future?
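
For reference, here's how I read the grouping distinction in the Javadoc, as a rough plain-Java sketch. `batchingUnits` is a hypothetical method, not something in the PR. It assumes partition ids are already resolved, and it ignores the `batchSizePerServer` cap and the min-available-replicas check. Routing-only mode treats each (partitionId, current instance set) pair as a unit. Strict assignment treats the whole partition as one unit (in practice a partition then has a single instance set anyway).

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.function.ToIntFunction;

final class BatchingUnitsSketch {

  // Returns the groups of segments that would move together as one unit within a batched rebalance step.
  // strictAssignment == true: one unit per partitionId (a partition must move as a whole).
  // strictAssignment == false (routing only): one unit per (partitionId, current instance set), i.e. the
  // leaves of a partitionId -> assigned instances -> segments map.
  static List<List<String>> batchingUnits(Map<String, Map<String, String>> currentAssignment,
      ToIntFunction<String> partitionIdOf, boolean strictAssignment) {
    Map<Integer, Map<Set<String>, List<String>>> grouped = new TreeMap<>();
    for (Map.Entry<String, Map<String, String>> entry : currentAssignment.entrySet()) {
      String segment = entry.getKey();
      // Copy the key set so the grouping key cannot change if the instance state map is mutated later.
      Set<String> instances = Set.copyOf(entry.getValue().keySet());
      grouped.computeIfAbsent(partitionIdOf.applyAsInt(segment), p -> new HashMap<>())
          .computeIfAbsent(instances, k -> new ArrayList<>())
          .add(segment);
    }

    List<List<String>> units = new ArrayList<>();
    for (Map<Set<String>, List<String>> byInstances : grouped.values()) {
      if (strictAssignment) {
        // Merge every instance-set group of the partition into a single unit.
        List<String> wholePartition = new ArrayList<>();
        byInstances.values().forEach(wholePartition::addAll);
        units.add(wholePartition);
      } else {
        units.addAll(byInstances.values());
      }
    }
    return units;
  }

  public static void main(String[] args) {
    // Made-up assignment: three segments of partition 0 on two different server pairs, one segment of partition 1.
    Map<String, Map<String, String>> current = new TreeMap<>();
    current.put("seg_a", Map.of("server1", "ONLINE", "server2", "ONLINE"));
    current.put("seg_b", Map.of("server1", "ONLINE", "server2", "ONLINE"));
    current.put("seg_c", Map.of("server3", "ONLINE", "server4", "ONLINE"));
    current.put("seg_d", Map.of("server3", "ONLINE", "server4", "ONLINE"));
    Map<String, Integer> partitions = Map.of("seg_a", 0, "seg_b", 0, "seg_c", 0, "seg_d", 1);
    ToIntFunction<String> partitionIdOf = segment -> partitions.get(segment);

    System.out.println("routing-only units: " + batchingUnits(current, partitionIdOf, false).size()); // 3
    System.out.println("strict-assignment units: " + batchingUnits(current, partitionIdOf, true).size()); // 2
  }
}
```

The smaller routing-only units are presumably what lets the per-server batch size be honored more closely in that mode, while strict assignment can overshoot it because a whole partition has to move at once.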



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

