This is an automated email from the ASF dual-hosted git repository.

jxue pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/helix.git

commit 84cf7c2143348209c446aa4a4351e6a158742b83
Author: frankmu <[email protected]>
AuthorDate: Thu Oct 31 08:50:23 2024 -0700

    Make StickyRebalanceStrategy topology aware (#2944)
    
    * Make StickyRebalanceStrategy topology aware
    
    * move canAdd() to a separate function for extensibility
    
    * add todo
    
    * address comment and add test
    
    * constructor change for CapacityNode
    
    ---------
    
    Co-authored-by: Tengfei Mu <[email protected]>
---
 .../helix/controller/common/CapacityNode.java      | 104 ++++++++++++++++++---
 .../ResourceControllerDataProvider.java            |  19 ++--
 .../strategy/StickyRebalanceStrategy.java          |  55 ++++++++---
 .../java/org/apache/helix/common/ZkTestBase.java   |   8 --
 .../rebalancer/TestStickyRebalanceStrategy.java    |  29 +++---
 .../rebalancer/TestStickyRebalanceStrategy.java    |  88 ++++++++++++-----
 6 files changed, 227 insertions(+), 76 deletions(-)

diff --git a/helix-core/src/main/java/org/apache/helix/controller/common/CapacityNode.java b/helix-core/src/main/java/org/apache/helix/controller/common/CapacityNode.java
index 208ee7913..bd49040ff 100644
--- a/helix-core/src/main/java/org/apache/helix/controller/common/CapacityNode.java
+++ b/helix-core/src/main/java/org/apache/helix/controller/common/CapacityNode.java
@@ -21,23 +21,55 @@ package org.apache.helix.controller.common;
 
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.LinkedHashMap;
 import java.util.Map;
 import java.util.Set;
 
+import org.apache.helix.controller.rebalancer.topology.Topology;
+import org.apache.helix.model.ClusterConfig;
+import org.apache.helix.model.ClusterTopologyConfig;
+import org.apache.helix.model.InstanceConfig;
+
 /**
  * A Node is an entity that can serve capacity recording purpose. It has a capacity and knowledge
  * of partitions assigned to it, so it can decide if it can receive additional partitions.
  */
-public class CapacityNode {
+public class CapacityNode implements Comparable<CapacityNode> {
   private int _currentlyAssigned;
   private int _capacity;
-  private final String _id;
+  private final String _instanceName;
+  private final String _logicalId;
+  private final String _faultZone;
   private final Map<String, Set<String>> _partitionMap;
 
-  public CapacityNode(String id) {
-    _partitionMap = new HashMap<>();
-    _currentlyAssigned = 0;
-    this._id = id;
+  /**
+   * Constructor used for the non-topology-aware use case
+   * @param instanceName  The instance name of this node
+   * @param capacity  The capacity of this node
+   */
+  public CapacityNode(String instanceName, int capacity) {
+    this(instanceName, null, null, null);
+    this._capacity = capacity;
+  }
+
+  /**
+   * Constructor used for the topology-aware use case
+   * @param instanceName  The instance name of this node
+   * @param clusterConfig  The cluster config for the current Helix cluster
+   * @param clusterTopologyConfig  The cluster topology config for the current Helix cluster
+   * @param instanceConfig  The instance config for the current instance
+   */
+  public CapacityNode(String instanceName, ClusterConfig clusterConfig,
+      ClusterTopologyConfig clusterTopologyConfig, InstanceConfig instanceConfig) {
+    this._instanceName = instanceName;
+    this._logicalId = clusterTopologyConfig != null ? instanceConfig.getLogicalId(
+        clusterTopologyConfig.getEndNodeType()) : instanceName;
+    this._faultZone =
+        clusterConfig != null ? computeFaultZone(clusterConfig, instanceConfig) : null;
+    this._partitionMap = new HashMap<>();
+    this._capacity =
+        clusterConfig != null ? clusterConfig.getGlobalMaxPartitionAllowedPerInstance() : 0;
+    this._currentlyAssigned = 0;
   }
 
   /**
@@ -80,11 +112,27 @@ public class CapacityNode {
   }
 
   /**
-   * Get the ID of this node
-   * @return The ID of this node
+   * Get the instance name of this node
+   * @return The instance name of this node
    */
-  public String getId() {
-    return _id;
+  public String getInstanceName() {
+    return _instanceName;
+  }
+
+  /**
+   * Get the logical id of this node
+   * @return The logical id of this node
+   */
+  public String getLogicalId() {
+    return _logicalId;
+  }
+
+  /**
+   * Get the fault zone of this node
+   * @return The fault zone of this node
+   */
+  public String getFaultZone() {
+    return _faultZone;
   }
 
   /**
@@ -98,8 +146,40 @@ public class CapacityNode {
   @Override
   public String toString() {
     StringBuilder sb = new StringBuilder();
-    sb.append("##########\nname=").append(_id).append("\nassigned:").append(_currentlyAssigned)
-        .append("\ncapacity:").append(_capacity);
+    sb.append("##########\nname=").append(_instanceName).append("\nassigned:")
+        .append(_currentlyAssigned).append("\ncapacity:").append(_capacity).append("\nlogicalId:")
+        .append(_logicalId).append("\nfaultZone:").append(_faultZone);
     return sb.toString();
   }
+
+  @Override
+  public int compareTo(CapacityNode o) {
+    if (_logicalId != null) {
+      return _logicalId.compareTo(o.getLogicalId());
+    }
+    return _instanceName.compareTo(o.getInstanceName());
+  }
+
+  /**
+   * Computes the fault zone id based on the domain and fault zone type when topology is enabled.
+   * For example, when the domain is "zone=2, instance=testInstance" and the fault zone type is
+   * "zone", this function returns "2".
+   * If the fault zone type cannot be found, this function falls back to the instance name as the
+   * fault zone id.
+   * TODO: change the return value to logical id when no fault zone type is found. Also do the same
+   *  for the waged rebalancer in
+   *  helix-core/src/main/java/org/apache/helix/controller/rebalancer/waged/model/AssignableNode.java
+   */
+  private String computeFaultZone(ClusterConfig clusterConfig, InstanceConfig instanceConfig) {
+    LinkedHashMap<String, String> instanceTopologyMap =
+        Topology.computeInstanceTopologyMap(clusterConfig, instanceConfig.getInstanceName(),
+            instanceConfig, true /*earlyQuitTillFaultZone*/);
+
+    StringBuilder faultZoneStringBuilder = new StringBuilder();
+    for (Map.Entry<String, String> entry : instanceTopologyMap.entrySet()) {
+      faultZoneStringBuilder.append(entry.getValue());
+      faultZoneStringBuilder.append('/');
+    }
+    faultZoneStringBuilder.setLength(faultZoneStringBuilder.length() - 1);
+    return faultZoneStringBuilder.toString();
+  }
 }
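
A minimal standalone sketch of the idea behind computeFaultZone (plain Java; the helper name here is hypothetical, and the real code delegates the domain parsing to Topology.computeInstanceTopologyMap): walk the domain's key=value pairs in order, stop once the fault zone type is reached, and join the collected values with '/'.

    import java.util.LinkedHashMap;
    import java.util.Map;

    public class FaultZoneSketch {
      static String faultZonePath(String domain, String faultZoneType) {
        LinkedHashMap<String, String> topologyMap = new LinkedHashMap<>();
        for (String pair : domain.split(",")) {
          String[] kv = pair.trim().split("=", 2);
          topologyMap.put(kv[0], kv[1]);
          if (kv[0].equals(faultZoneType)) {
            break; // stop early at the fault zone, like earlyQuitTillFaultZone
          }
        }
        StringBuilder sb = new StringBuilder();
        for (Map.Entry<String, String> entry : topologyMap.entrySet()) {
          sb.append(entry.getValue()).append('/');
        }
        sb.setLength(sb.length() - 1); // drop the trailing '/'
        return sb.toString();
      }

      public static void main(String[] args) {
        // Matches the javadoc example: domain "zone=2, instance=testInstance"
        // with fault zone type "zone" yields "2"
        System.out.println(faultZonePath("zone=2, instance=testInstance", "zone"));
      }
    }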
diff --git a/helix-core/src/main/java/org/apache/helix/controller/dataproviders/ResourceControllerDataProvider.java b/helix-core/src/main/java/org/apache/helix/controller/dataproviders/ResourceControllerDataProvider.java
index eae6af8ce..50088e69c 100644
--- a/helix-core/src/main/java/org/apache/helix/controller/dataproviders/ResourceControllerDataProvider.java
+++ b/helix-core/src/main/java/org/apache/helix/controller/dataproviders/ResourceControllerDataProvider.java
@@ -46,6 +46,8 @@ import org.apache.helix.controller.rebalancer.waged.WagedInstanceCapacity;
 import org.apache.helix.controller.rebalancer.waged.WagedResourceWeightsProvider;
 import org.apache.helix.controller.stages.CurrentStateOutput;
 import org.apache.helix.controller.stages.MissingTopStateRecord;
+import org.apache.helix.model.ClusterConfig;
+import org.apache.helix.model.ClusterTopologyConfig;
 import org.apache.helix.model.CustomizedState;
 import org.apache.helix.model.CustomizedStateConfig;
 import org.apache.helix.model.CustomizedView;
@@ -190,7 +192,7 @@ public class ResourceControllerDataProvider extends BaseControllerDataProvider {
 
     if (getClusterConfig() != null
         && getClusterConfig().getGlobalMaxPartitionAllowedPerInstance() != -1) {
-      buildSimpleCapacityMap(getClusterConfig().getGlobalMaxPartitionAllowedPerInstance());
+      buildSimpleCapacityMap();
       // Remove all cached IdealState because it is a global computation that cannot be
       // partially performed for some resources. The computation is simple as well, not taking
       // too much resource to recompute the assignments.
@@ -581,11 +583,16 @@ public class ResourceControllerDataProvider extends BaseControllerDataProvider {
     return _wagedInstanceCapacity;
   }
 
-  private void buildSimpleCapacityMap(int globalMaxPartitionAllowedPerInstance) {
+  private void buildSimpleCapacityMap() {
+    ClusterConfig clusterConfig = getClusterConfig();
+    ClusterTopologyConfig clusterTopologyConfig =
+        ClusterTopologyConfig.createFromClusterConfig(clusterConfig);
+    Map<String, InstanceConfig> instanceConfigMap = getAssignableInstanceConfigMap();
     _simpleCapacitySet = new HashSet<>();
-    for (String instance : getEnabledLiveInstances()) {
-      CapacityNode capacityNode = new CapacityNode(instance);
-      capacityNode.setCapacity(globalMaxPartitionAllowedPerInstance);
+    for (String instanceName : getAssignableInstances()) {
+      CapacityNode capacityNode =
+          new CapacityNode(instanceName, clusterConfig, clusterTopologyConfig,
+              instanceConfigMap.getOrDefault(instanceName, new InstanceConfig(instanceName)));
       _simpleCapacitySet.add(capacityNode);
     }
   }
@@ -599,7 +606,7 @@ public class ResourceControllerDataProvider extends BaseControllerDataProvider {
     // Convert the assignableNodes to map for quick lookup
     Map<String, CapacityNode> simpleCapacityMap = new HashMap<>();
     for (CapacityNode node : _simpleCapacitySet) {
-      simpleCapacityMap.put(node.getId(), node);
+      simpleCapacityMap.put(node.getInstanceName(), node);
     }
     for (String resourceName : resourceNameSet) {
       // Process current state mapping
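
The capacity set built above is later sorted by the rebalance strategy via CapacityNode's natural order, which prefers the logical id and falls back to the instance name. A minimal sketch of that ordering (plain Java with a hypothetical Node class, not the Helix types); note the comparison is lexicographic on strings:

    import java.util.ArrayList;
    import java.util.Collections;
    import java.util.List;

    public class NodeOrderingSketch {
      static class Node implements Comparable<Node> {
        final String instanceName;
        final String logicalId; // null on the non-topology-aware constructor path

        Node(String instanceName, String logicalId) {
          this.instanceName = instanceName;
          this.logicalId = logicalId;
        }

        @Override
        public int compareTo(Node other) {
          // Mirrors CapacityNode.compareTo: prefer logical id, else instance name
          return logicalId != null ? logicalId.compareTo(other.logicalId)
              : instanceName.compareTo(other.instanceName);
        }
      }

      public static void main(String[] args) {
        List<Node> nodes = new ArrayList<>();
        nodes.add(new Node("host-b", "2"));
        nodes.add(new Node("host-a", "10"));
        Collections.sort(nodes);
        // Lexicographically "10" < "2", so host-a sorts first
        System.out.println(nodes.get(0).instanceName); // host-a
      }
    }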
diff --git a/helix-core/src/main/java/org/apache/helix/controller/rebalancer/strategy/StickyRebalanceStrategy.java b/helix-core/src/main/java/org/apache/helix/controller/rebalancer/strategy/StickyRebalanceStrategy.java
index 3c3793cec..0471f128e 100644
--- a/helix-core/src/main/java/org/apache/helix/controller/rebalancer/strategy/StickyRebalanceStrategy.java
+++ b/helix-core/src/main/java/org/apache/helix/controller/rebalancer/strategy/StickyRebalanceStrategy.java
@@ -20,7 +20,6 @@ package org.apache.helix.controller.rebalancer.strategy;
  */
 
 import java.util.ArrayList;
-import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.LinkedHashMap;
@@ -74,11 +73,15 @@ public class StickyRebalanceStrategy implements RebalanceStrategy<ResourceContro
     // Note the liveNodes parameter here might be processed within the rebalancer, e.g. filter based on tags
     Set<CapacityNode> assignableNodeSet = new HashSet<>(clusterData.getSimpleCapacitySet());
     Set<String> liveNodesSet = new HashSet<>(liveNodes);
-    assignableNodeSet.removeIf(n -> !liveNodesSet.contains(n.getId()));
+    assignableNodeSet.removeIf(n -> !liveNodesSet.contains(n.getInstanceName()));
+
+    // Convert the assignableNodes to map for quick lookup
+    Map<String, CapacityNode> assignableNodeMap = assignableNodeSet.stream()
+        .collect(Collectors.toMap(CapacityNode::getInstanceName, node -> node));
 
     //  Populate valid state map given current mapping
     Map<String, Set<String>> stateMap =
-        populateValidAssignmentMapFromCurrentMapping(currentMapping, assignableNodeSet);
+        populateValidAssignmentMapFromCurrentMapping(currentMapping, assignableNodeMap);
 
     if (logger.isDebugEnabled()) {
       logger.debug("currentMapping: {}", currentMapping);
@@ -86,22 +89,33 @@ public class StickyRebalanceStrategy implements RebalanceStrategy<ResourceContro
     }
 
     // Sort the assignable nodes by id
-    List<CapacityNode> assignableNodeList =
-        assignableNodeSet.stream().sorted(Comparator.comparing(CapacityNode::getId))
+    List<CapacityNode> assignableNodeList = assignableNodeSet.stream().sorted()
             .collect(Collectors.toList());
 
     // Assign partitions to node by order.
     for (int i = 0, index = 0; i < _partitions.size(); i++) {
       int startIndex = index;
+      Map<String, Integer> currentFaultZoneCountMap = new HashMap<>();
       int remainingReplica = _statesReplicaCount;
       if (stateMap.containsKey(_partitions.get(i))) {
-        remainingReplica = remainingReplica - stateMap.get(_partitions.get(i)).size();
+        Set<String> existingReplicas = stateMap.get(_partitions.get(i));
+        remainingReplica = remainingReplica - existingReplicas.size();
+        for (String instanceName : existingReplicas) {
+          String faultZone = assignableNodeMap.get(instanceName).getFaultZone();
+          currentFaultZoneCountMap.put(faultZone,
+              currentFaultZoneCountMap.getOrDefault(faultZone, 0) + 1);
+        }
       }
       for (int j = 0; j < remainingReplica; j++) {
         while (index - startIndex < assignableNodeList.size()) {
           CapacityNode node = assignableNodeList.get(index++ % assignableNodeList.size());
-          if (node.canAdd(_resourceName, _partitions.get(i))) {
-            stateMap.computeIfAbsent(_partitions.get(i), m -> new HashSet<>()).add(node.getId());
+          if (this.canAdd(node, _partitions.get(i), currentFaultZoneCountMap)) {
+            stateMap.computeIfAbsent(_partitions.get(i), m -> new HashSet<>())
+                .add(node.getInstanceName());
+            if (node.getFaultZone() != null) {
+              currentFaultZoneCountMap.put(node.getFaultZone(),
+                  currentFaultZoneCountMap.getOrDefault(node.getFaultZone(), 0) + 1);
+            }
             break;
           }
         }
@@ -126,16 +140,13 @@ public class StickyRebalanceStrategy implements RebalanceStrategy<ResourceContro
    * Populates a valid state map from the current mapping, filtering out invalid nodes.
    *
    * @param currentMapping   the current mapping of partitions to node states
-   * @param assignableNodes  the list of nodes that can be assigned
+   * @param assignableNodeMap  the map of instance name -> nodes that can be assigned
    * @return a map of partitions to valid nodes
    */
   private Map<String, Set<String>> populateValidAssignmentMapFromCurrentMapping(
       final Map<String, Map<String, String>> currentMapping,
-      final Set<CapacityNode> assignableNodes) {
+      final Map<String, CapacityNode> assignableNodeMap) {
     Map<String, Set<String>> validAssignmentMap = new HashMap<>();
-    // Convert the assignableNodes to map for quick lookup
-    Map<String, CapacityNode> assignableNodeMap =
-        assignableNodes.stream().collect(Collectors.toMap(CapacityNode::getId, node -> node));
     if (currentMapping != null) {
       for (Map.Entry<String, Map<String, String>> entry : currentMapping.entrySet()) {
         String partition = entry.getKey();
@@ -167,4 +178,22 @@ public class StickyRebalanceStrategy implements RebalanceStrategy<ResourceContro
     return node != null && (node.hasPartition(_resourceName, partition) || node.canAdd(
         _resourceName, partition));
   }
+
+  /**
+   * Checks if it is valid to assign the partition to the node
+   *
+   * @param node           node to assign partition
+   * @param partition      partition name
+   * @param currentFaultZoneCountMap   the map of fault zones -> count
+   * @return true if it is valid to assign the partition to the node, false otherwise
+   */
+  protected boolean canAdd(CapacityNode node, String partition,
+      Map<String, Integer> currentFaultZoneCountMap) {
+    // Valid assignment when the following conditions hold:
+    // 1. The replica is not within the same fault zone as other replicas
+    // 2. The node has capacity to hold the replica
+    return !currentFaultZoneCountMap.containsKey(node.getFaultZone()) && node.canAdd(_resourceName,
+        partition);
+  }
 }
+
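
A rough illustration of the rule the new canAdd enforces (plain Java with a hypothetical Node type; capacity is reduced to a simple counter): walking nodes in sorted order, a replica lands on a node only if its fault zone is not already used by the partition and the node has spare capacity.

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.HashSet;
    import java.util.List;
    import java.util.Set;

    public class FaultZoneAwareAssignSketch {
      static class Node {
        final String name;
        final String faultZone;
        final int capacity;
        int assigned; // partitions currently placed on this node

        Node(String name, String faultZone, int capacity) {
          this.name = name;
          this.faultZone = faultZone;
          this.capacity = capacity;
        }
      }

      public static void main(String[] args) {
        List<Node> nodes = Arrays.asList(new Node("host-0", "zone-0", 1),
            new Node("host-1", "zone-0", 1), new Node("host-2", "zone-1", 1));
        int replicas = 2;
        Set<String> usedZones = new HashSet<>(); // zones already holding a replica
        List<String> placement = new ArrayList<>();
        for (Node node : nodes) {
          if (placement.size() == replicas) {
            break;
          }
          // Mirrors canAdd: fault zone unused AND the node has capacity left
          if (!usedZones.contains(node.faultZone) && node.assigned < node.capacity) {
            placement.add(node.name);
            usedZones.add(node.faultZone);
            node.assigned++;
          }
        }
        System.out.println(placement); // [host-0, host-2] -- host-1 shares zone-0
      }
    }

In the real strategy the used zones are tracked per partition in currentFaultZoneCountMap, and only non-null fault zones are recorded, so the non-topology-aware path reduces to a pure capacity check.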
diff --git a/helix-core/src/test/java/org/apache/helix/common/ZkTestBase.java b/helix-core/src/test/java/org/apache/helix/common/ZkTestBase.java
index 1f08fb570..50ab9c93a 100644
--- a/helix-core/src/test/java/org/apache/helix/common/ZkTestBase.java
+++ b/helix-core/src/test/java/org/apache/helix/common/ZkTestBase.java
@@ -367,14 +367,6 @@ public class ZkTestBase {
     configAccessor.setClusterConfig(clusterName, clusterConfig);
   }
 
-  protected void setGlobalMaxPartitionAllowedPerInstanceInCluster(HelixZkClient zkClient,
-      String clusterName, int maxPartitionAllowed) {
-    ConfigAccessor configAccessor = new ConfigAccessor(zkClient);
-    ClusterConfig clusterConfig = configAccessor.getClusterConfig(clusterName);
-    clusterConfig.setGlobalMaxPartitionAllowedPerInstance(maxPartitionAllowed);
-    configAccessor.setClusterConfig(clusterName, clusterConfig);
-  }
-
   protected IdealState createResourceWithDelayedRebalance(String clusterName, String db,
       String stateModel, int numPartition, int replica, int minActiveReplica, long delay) {
     return createResourceWithDelayedRebalance(clusterName, db, stateModel, numPartition, replica,
diff --git a/helix-core/src/test/java/org/apache/helix/controller/rebalancer/TestStickyRebalanceStrategy.java b/helix-core/src/test/java/org/apache/helix/controller/rebalancer/TestStickyRebalanceStrategy.java
index 45211df4e..acd30c2c7 100644
--- a/helix-core/src/test/java/org/apache/helix/controller/rebalancer/TestStickyRebalanceStrategy.java
+++ b/helix-core/src/test/java/org/apache/helix/controller/rebalancer/TestStickyRebalanceStrategy.java
@@ -53,13 +53,12 @@ public class TestStickyRebalanceStrategy {
 
     Set<CapacityNode> capacityNodeSet = new HashSet<>();
     for (int i = 0; i < 5; i++) {
-      CapacityNode capacityNode = new CapacityNode("Node-" + i);
-      capacityNode.setCapacity(1);
+      CapacityNode capacityNode = new CapacityNode("Node-" + i, 1);
       capacityNodeSet.add(capacityNode);
     }
 
     List<String> liveNodes =
-        capacityNodeSet.stream().map(CapacityNode::getId).collect(Collectors.toList());
+        capacityNodeSet.stream().map(CapacityNode::getInstanceName).collect(Collectors.toList());
 
     List<String> partitions = new ArrayList<>();
     for (int i = 0; i < 3; i++) {
@@ -97,13 +96,12 @@ public class TestStickyRebalanceStrategy {
 
     Set<CapacityNode> capacityNodeSet = new HashSet<>();
     for (int i = 0; i < nNode; i++) {
-      CapacityNode capacityNode = new CapacityNode("Node-" + i);
-      capacityNode.setCapacity(1);
+      CapacityNode capacityNode = new CapacityNode("Node-" + i, 1);
       capacityNodeSet.add(capacityNode);
     }
 
     List<String> liveNodes =
-        capacityNodeSet.stream().map(CapacityNode::getId).collect(Collectors.toList());
+        capacityNodeSet.stream().map(CapacityNode::getInstanceName).collect(Collectors.toList());
 
     List<String> partitions = new ArrayList<>();
     for (int i = 0; i < nPartitions; i++) {
@@ -150,13 +148,12 @@ public class TestStickyRebalanceStrategy {
 
     Set<CapacityNode> capacityNodeSet = new HashSet<>();
     for (int i = 0; i < nNode; i++) {
-      CapacityNode capacityNode = new CapacityNode("Node-" + i);
-      capacityNode.setCapacity(1);
+      CapacityNode capacityNode = new CapacityNode("Node-" + i, 1);
       capacityNodeSet.add(capacityNode);
     }
 
     List<String> liveNodes =
-        capacityNodeSet.stream().map(CapacityNode::getId).collect(Collectors.toList());
+        capacityNodeSet.stream().map(CapacityNode::getInstanceName).collect(Collectors.toList());
 
     List<String> partitions = new ArrayList<>();
     for (int i = 0; i < nPartitions; i++) {
@@ -164,13 +161,13 @@ public class TestStickyRebalanceStrategy {
     }
     when(clusterDataCache.getSimpleCapacitySet()).thenReturn(capacityNodeSet);
 
-    StickyRebalanceStrategy greedyRebalanceStrategy = new StickyRebalanceStrategy();
-    greedyRebalanceStrategy.init(TEST_RESOURCE_PREFIX + 0, partitions, states, 1);
+    StickyRebalanceStrategy stickyRebalanceStrategy = new StickyRebalanceStrategy();
+    stickyRebalanceStrategy.init(TEST_RESOURCE_PREFIX + 0, partitions, states, 1);
     // First round assignment computation:
     // 1. Without previous assignment (currentMapping is null)
     // 2. Without enough assignable nodes
     ZNRecord firstRoundShardAssignment =
-        greedyRebalanceStrategy.computePartitionAssignment(null, liveNodes, null, clusterDataCache);
+        stickyRebalanceStrategy.computePartitionAssignment(null, liveNodes, null, clusterDataCache);
 
     // Assert only 3 partitions are fulfilled with assignment
     Assert.assertEquals(firstRoundShardAssignment.getListFields().entrySet().stream()
@@ -178,12 +175,12 @@ public class TestStickyRebalanceStrategy {
 
     // Assign 4 more nodes which are used in the second round assignment computation
     for (int i = nNode; i < nNode + 4; i++) {
-      CapacityNode capacityNode = new CapacityNode("Node-" + i);
-      capacityNode.setCapacity(1);
+      CapacityNode capacityNode = new CapacityNode("Node-" + i, 1);
       capacityNodeSet.add(capacityNode);
     }
 
-    liveNodes = capacityNodeSet.stream().map(CapacityNode::getId).collect(Collectors.toList());
+    liveNodes =
+        capacityNodeSet.stream().map(CapacityNode::getInstanceName).collect(Collectors.toList());
 
     // Populate previous assignment (currentMapping) with first round assignment computation result
     Map<String, Map<String, String>> currentMapping = new HashMap<>();
@@ -199,7 +196,7 @@ public class TestStickyRebalanceStrategy {
     // 1. With previous assignment (currentMapping)
     // 2. With enough assignable nodes
     ZNRecord secondRoundShardAssignment =
-        greedyRebalanceStrategy.computePartitionAssignment(null, liveNodes, currentMapping,
+        stickyRebalanceStrategy.computePartitionAssignment(null, liveNodes, currentMapping,
             clusterDataCache);
 
     // Assert all partitions have been assigned with enough replica
diff --git a/helix-core/src/test/java/org/apache/helix/integration/rebalancer/TestStickyRebalanceStrategy.java b/helix-core/src/test/java/org/apache/helix/integration/rebalancer/TestStickyRebalanceStrategy.java
index 97a27017e..04860d692 100644
--- a/helix-core/src/test/java/org/apache/helix/integration/rebalancer/TestStickyRebalanceStrategy.java
+++ b/helix-core/src/test/java/org/apache/helix/integration/rebalancer/TestStickyRebalanceStrategy.java
@@ -34,8 +34,10 @@ import org.apache.helix.controller.rebalancer.strategy.StickyRebalanceStrategy;
 import org.apache.helix.integration.manager.ClusterControllerManager;
 import org.apache.helix.integration.manager.MockParticipantManager;
 import org.apache.helix.model.BuiltInStateModelDefinitions;
+import org.apache.helix.model.ClusterConfig;
 import org.apache.helix.model.ExternalView;
 import org.apache.helix.model.IdealState;
+import org.apache.helix.model.InstanceConfig;
 import org.apache.helix.tools.ClusterVerifiers.BestPossibleExternalViewVerifier;
 import org.apache.helix.tools.ClusterVerifiers.ZkHelixClusterVerifier;
 import org.testng.Assert;
@@ -57,6 +59,7 @@ public class TestStickyRebalanceStrategy extends ZkTestBase {
   protected ClusterControllerManager _controller;
   protected List<MockParticipantManager> _participants = new ArrayList<>();
   protected List<MockParticipantManager> _additionalParticipants = new ArrayList<>();
+  protected Map<String, String> _instanceNameZoneMap = new HashMap<>();
   protected int _minActiveReplica = 0;
   protected ZkHelixClusterVerifier _clusterVerifier;
   protected List<String> _testDBs = new ArrayList<>();
@@ -67,27 +70,17 @@ public class TestStickyRebalanceStrategy extends ZkTestBase {
   @BeforeClass
   public void beforeClass() throws Exception {
     System.out.println("START " + CLASS_NAME + " at " + new 
Date(System.currentTimeMillis()));
+    _configAccessor = new ConfigAccessor(_gZkClient);
 
     _gSetupTool.addCluster(CLUSTER_NAME, true);
 
     for (int i = 0; i < NUM_NODE; i++) {
-      String storageNodeName = PARTICIPANT_PREFIX + "_" + (START_PORT + i);
-      _gSetupTool.addInstanceToCluster(CLUSTER_NAME, storageNodeName);
-
-      // start dummy participants
-      MockParticipantManager participant =
-          new MockParticipantManager(ZK_ADDR, CLUSTER_NAME, storageNodeName);
-      participant.syncStart();
-      _participants.add(participant);
+      _participants.addAll(addInstance("" + START_PORT + i, "zone-" + i % REPLICAS, true));
     }
 
     for (int i = NUM_NODE; i < NUM_NODE + ADDITIONAL_NUM_NODE; i++) {
-      String storageNodeName = PARTICIPANT_PREFIX + "_" + (START_PORT + i);
-      _gSetupTool.addInstanceToCluster(CLUSTER_NAME, storageNodeName);
-
-      MockParticipantManager participant =
-          new MockParticipantManager(ZK_ADDR, CLUSTER_NAME, storageNodeName);
-      _additionalParticipants.add(participant);
+      _additionalParticipants.addAll(
+          addInstance("" + START_PORT + i, "zone-" + i % REPLICAS, false));
     }
 
     // start controller
@@ -147,9 +140,27 @@ public class TestStickyRebalanceStrategy extends ZkTestBase {
     _clusterVerifier.verifyByPolling();
   }
 
+  @Test
+  public void testNoSameZoneAssignment() throws Exception {
+    setTopologyAwareAndGlobalMaxPartitionAllowedPerInstanceInCluster(CLUSTER_NAME, 1);
+    Map<String, ExternalView> externalViews = createTestDBs();
+    for (ExternalView ev : externalViews.values()) {
+      Map<String, Map<String, String>> assignments = ev.getRecord().getMapFields();
+      Assert.assertNotNull(assignments);
+      Assert.assertEquals(assignments.size(), PARTITIONS);
+      for (Map<String, String> assignmentMap : assignments.values()) {
+        Assert.assertEquals(assignmentMap.keySet().size(), REPLICAS);
+        Set<String> zoneSet = new HashSet<>();
+        for (String instanceName : assignmentMap.keySet()) {
+          zoneSet.add(_instanceNameZoneMap.get(instanceName));
+        }
+        Assert.assertEquals(zoneSet.size(), REPLICAS);
+      }
+    }
+  }
   @Test
   public void testFirstTimeAssignmentWithNoInitialLiveNodes() throws Exception {
-    setGlobalMaxPartitionAllowedPerInstanceInCluster(_gZkClient, CLUSTER_NAME, 1);
+    setTopologyAwareAndGlobalMaxPartitionAllowedPerInstanceInCluster(CLUSTER_NAME, 1);
     // Shut down all the nodes
     for (int i = 0; i < NUM_NODE; i++) {
       _participants.get(i).syncStop();
@@ -175,7 +186,7 @@ public class TestStickyRebalanceStrategy extends ZkTestBase {
 
   @Test
   public void testNoPartitionMovementWithNewInstanceAdd() throws Exception {
-    setGlobalMaxPartitionAllowedPerInstanceInCluster(_gZkClient, CLUSTER_NAME, 1);
+    setTopologyAwareAndGlobalMaxPartitionAllowedPerInstanceInCluster(CLUSTER_NAME, 1);
     Map<String, ExternalView> externalViewsBefore = createTestDBs();
 
     // Start more new instances
@@ -197,7 +208,7 @@ public class TestStickyRebalanceStrategy extends ZkTestBase {
 
   @Test
   public void testNoPartitionMovementWithInstanceDown() throws Exception {
-    setGlobalMaxPartitionAllowedPerInstanceInCluster(_gZkClient, CLUSTER_NAME, 1);
+    setTopologyAwareAndGlobalMaxPartitionAllowedPerInstanceInCluster(CLUSTER_NAME, 1);
     Map<String, ExternalView> externalViewsBefore = createTestDBs();
 
     // Shut down 2 instances
@@ -220,7 +231,7 @@ public class TestStickyRebalanceStrategy extends ZkTestBase {
 
   @Test
   public void testNoPartitionMovementWithInstanceRestart() throws Exception {
-    setGlobalMaxPartitionAllowedPerInstanceInCluster(_gZkClient, CLUSTER_NAME, 1);
+    setTopologyAwareAndGlobalMaxPartitionAllowedPerInstanceInCluster(CLUSTER_NAME, 1);
     // Create resource
     Map<String, ExternalView> externalViewsBefore = createTestDBs();
     // Shut down half of the nodes
@@ -265,14 +276,14 @@ public class TestStickyRebalanceStrategy extends ZkTestBase {
 
   @Test
   public void testFirstTimeAssignmentWithStackingPlacement() throws Exception {
-    setGlobalMaxPartitionAllowedPerInstanceInCluster(_gZkClient, CLUSTER_NAME, 2);
+    setTopologyAwareAndGlobalMaxPartitionAllowedPerInstanceInCluster(CLUSTER_NAME, 2);
     Map<String, ExternalView> externalViewsBefore = createTestDBs();
     validateAllPartitionAssigned(externalViewsBefore);
   }
 
   @Test
   public void testNoPartitionMovementWithNewInstanceAddWithStackingPlacement() throws Exception {
-    setGlobalMaxPartitionAllowedPerInstanceInCluster(_gZkClient, CLUSTER_NAME, 2);
+    setTopologyAwareAndGlobalMaxPartitionAllowedPerInstanceInCluster(CLUSTER_NAME, 2);
     Map<String, ExternalView> externalViewsBefore = createTestDBs();
 
     // Start more new instances
@@ -294,7 +305,7 @@ public class TestStickyRebalanceStrategy extends ZkTestBase {
 
   @Test
   public void testNoPartitionMovementWithInstanceDownWithStackingPlacement() throws Exception {
-    setGlobalMaxPartitionAllowedPerInstanceInCluster(_gZkClient, CLUSTER_NAME, 2);
+    setTopologyAwareAndGlobalMaxPartitionAllowedPerInstanceInCluster(CLUSTER_NAME, 2);
     // Shut down half of the nodes given we allow stacking placement
     for (int i = 0; i < NUM_NODE / 2; i++) {
       _participants.get(i).syncStop();
@@ -395,4 +406,39 @@ public class TestStickyRebalanceStrategy extends ZkTestBase {
       }
     }
   }
+
+  private void setTopologyAwareAndGlobalMaxPartitionAllowedPerInstanceInCluster(String clusterName,
+      int maxPartitionAllowed) {
+    ClusterConfig clusterConfig = _configAccessor.getClusterConfig(clusterName);
+    clusterConfig.setTopology("/zone/host/logicalId");
+    clusterConfig.setFaultZoneType("zone");
+    clusterConfig.setTopologyAwareEnabled(true);
+    clusterConfig.setGlobalMaxPartitionAllowedPerInstance(maxPartitionAllowed);
+    _configAccessor.setClusterConfig(clusterName, clusterConfig);
+  }
+
+  private List<MockParticipantManager> addInstance(String instanceNameSuffix, String zone,
+      boolean enabled) {
+    List<MockParticipantManager> participants = new ArrayList<>();
+    String storageNodeName = PARTICIPANT_PREFIX + "_" + instanceNameSuffix;
+    _instanceNameZoneMap.put(storageNodeName, zone);
+    _gSetupTool.addInstanceToCluster(CLUSTER_NAME, storageNodeName);
+
+    String domain =
+        String.format("zone=%s,host=%s,logicalId=%s", zone, storageNodeName, instanceNameSuffix);
+    InstanceConfig instanceConfig =
+        _configAccessor.getInstanceConfig(CLUSTER_NAME, storageNodeName);
+    instanceConfig.setDomain(domain);
+    _gSetupTool.getClusterManagementTool()
+        .setInstanceConfig(CLUSTER_NAME, storageNodeName, instanceConfig);
+    MockParticipantManager participant =
+        new MockParticipantManager(ZK_ADDR, CLUSTER_NAME, storageNodeName);
+    if (enabled) {
+      // start dummy participant
+      participant.syncStart();
+    }
+    participants.add(participant);
+
+    return participants;
+  }
 }
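
For reference, a sketch of the domain strings the setup above produces; the PARTICIPANT_PREFIX ("localhost") and START_PORT (12918) values are assumptions for illustration and are not shown in this diff. Note that "" + START_PORT + i concatenates strings, so i = 1 yields the suffix "129181" rather than 12919:

    public class DomainSketch {
      public static void main(String[] args) {
        int startPort = 12918; // assumed value of START_PORT
        int replicas = 3;      // assumed value of REPLICAS
        for (int i = 0; i < 6; i++) {
          String suffix = "" + startPort + i;   // string concat, e.g. "129180"
          String zone = "zone-" + i % replicas; // instances cycle through zones
          System.out.printf("zone=%s,host=localhost_%s,logicalId=%s%n",
              zone, suffix, suffix);
        }
      }
    }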
