This is an automated email from the ASF dual-hosted git repository.

xyuanlu pushed a commit to branch helix-gateway-service
in repository https://gitbox.apache.org/repos/asf/helix.git

commit fd092b51b985fd3c5d9b8af12cab132dc6a38eb4
Author: xyuanlu <xyua...@gmail.com>
AuthorDate: Thu Sep 12 20:50:11 2024 -0700

    Gateway - gateway participant update target state in cache (#2910)
    
    Gateway - gateway participant update target state in cache
---
 .../channel/HelixGatewayServiceGrpcService.java    |  13 +--
 .../participant/HelixGatewayParticipant.java       |  79 +++++--------
 .../gateway/service/GatewayServiceManager.java     |  48 +++++---
 .../gateway/util/GatewayCurrentStateCache.java     | 126 +++++++++++----------
 .../util/StateTransitionMessageTranslateUtil.java  |  12 +-
 .../participant/TestHelixGatewayParticipant.java   |  75 ++++++++++--
 .../gateway/util/TestGatewayCurrentStateCache.java |  35 +++---
 7 files changed, 226 insertions(+), 162 deletions(-)

diff --git 
a/helix-gateway/src/main/java/org/apache/helix/gateway/channel/HelixGatewayServiceGrpcService.java
 
b/helix-gateway/src/main/java/org/apache/helix/gateway/channel/HelixGatewayServiceGrpcService.java
index a9c1ca4a2..6a2c76b8a 100644
--- 
a/helix-gateway/src/main/java/org/apache/helix/gateway/channel/HelixGatewayServiceGrpcService.java
+++ 
b/helix-gateway/src/main/java/org/apache/helix/gateway/channel/HelixGatewayServiceGrpcService.java
@@ -38,7 +38,6 @@ import 
org.apache.helix.gateway.util.StateTransitionMessageTranslateUtil;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import proto.org.apache.helix.gateway.HelixGatewayServiceGrpc;
-import proto.org.apache.helix.gateway.HelixGatewayServiceOuterClass;
 import 
proto.org.apache.helix.gateway.HelixGatewayServiceOuterClass.ShardChangeRequests;
 import proto.org.apache.helix.gateway.HelixGatewayServiceOuterClass.ShardState;
 import 
proto.org.apache.helix.gateway.HelixGatewayServiceOuterClass.ShardStateMessage;
@@ -92,7 +91,7 @@ public class HelixGatewayServiceGrpcService extends 
HelixGatewayServiceGrpc.Heli
           updateObserver(shardState.getInstanceName(), 
shardState.getClusterName(), responseObserver);
         }
         pushClientEventToGatewayManager(_manager,
-            
StateTransitionMessageTranslateUtil.translateShardStateMessageToEvent(request));
+            
StateTransitionMessageTranslateUtil.translateShardStateMessageToEventAndUpdateCache(_manager,
 request));
       }
 
       @Override
@@ -120,9 +119,8 @@ public class HelixGatewayServiceGrpcService extends 
HelixGatewayServiceGrpc.Heli
    */
   @Override
   public void sendStateChangeRequests(String instanceName, ShardChangeRequests 
requests) {
-    StreamObserver<HelixGatewayServiceOuterClass.ShardChangeRequests> observer;
-    observer = _observerMap.get(instanceName);
-    if (observer != null) {
+    StreamObserver<ShardChangeRequests> observer = 
_observerMap.get(instanceName);
+    if (observer!= null) {
       observer.onNext(requests);
     } else {
       logger.error("Instance {} is not connected to the gateway service", 
instanceName);
@@ -151,8 +149,7 @@ public class HelixGatewayServiceGrpcService extends 
HelixGatewayServiceGrpc.Heli
   }
 
   private void closeConnectionHelper(String instanceName, String errorReason, 
boolean withError) {
-    StreamObserver<ShardChangeRequests> observer;
-    observer = _observerMap.get(instanceName);
+    StreamObserver<ShardChangeRequests> observer = 
_observerMap.get(instanceName);
     if (observer != null) {
       if (withError) {
         
observer.onError(Status.UNAVAILABLE.withDescription(errorReason).asRuntimeException());
@@ -162,7 +159,7 @@ public class HelixGatewayServiceGrpcService extends 
HelixGatewayServiceGrpc.Heli
     }
   }
 
-  public void onClientClose(String clusterName, String instanceName) {
+   private void onClientClose(String clusterName, String instanceName) {
     if (instanceName == null || clusterName == null) {
       // TODO: log error;
       return;
diff --git 
a/helix-gateway/src/main/java/org/apache/helix/gateway/participant/HelixGatewayParticipant.java
 
b/helix-gateway/src/main/java/org/apache/helix/gateway/participant/HelixGatewayParticipant.java
index b17d897e1..8dd04644b 100644
--- 
a/helix-gateway/src/main/java/org/apache/helix/gateway/participant/HelixGatewayParticipant.java
+++ 
b/helix-gateway/src/main/java/org/apache/helix/gateway/participant/HelixGatewayParticipant.java
@@ -19,17 +19,15 @@ package org.apache.helix.gateway.participant;
  * under the License.
  */
 
-import com.google.common.annotations.VisibleForTesting;
 import java.util.ArrayList;
-import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentHashMap;
-import org.apache.helix.HelixDefinedState;
 import org.apache.helix.HelixManager;
 import org.apache.helix.InstanceType;
 import org.apache.helix.gateway.api.service.HelixGatewayServiceChannel;
+import org.apache.helix.gateway.service.GatewayServiceManager;
 import 
org.apache.helix.gateway.statemodel.HelixGatewayMultiTopStateStateModelFactory;
 import org.apache.helix.gateway.util.StateTransitionMessageTranslateUtil;
 import org.apache.helix.manager.zk.HelixManagerStateListener;
@@ -37,6 +35,7 @@ import org.apache.helix.manager.zk.ZKHelixManager;
 import org.apache.helix.model.Message;
 import org.apache.helix.participant.statemachine.StateTransitionError;
 
+
 /**
  * HelixGatewayParticipant encapsulates the Helix Participant Manager and 
handles tracking the state
  * of a remote participant connected to the Helix Gateway Service. It 
processes state transitions
@@ -48,18 +47,18 @@ public class HelixGatewayParticipant implements 
HelixManagerStateListener {
   private final HelixGatewayServiceChannel _gatewayServiceChannel;
   private final HelixManager _helixManager;
   private final Runnable _onDisconnectedCallback;
-  private final Map<String, Map<String, String>> _shardStateMap;
-
   private final Map<String, CompletableFuture<String>> 
_stateTransitionResultMap;
 
-  private HelixGatewayParticipant(HelixGatewayServiceChannel 
gatewayServiceChannel,
-      Runnable onDisconnectedCallback, HelixManager helixManager,
-      Map<String, Map<String, String>> initialShardStateMap) {
+  private final GatewayServiceManager _gatewayServiceManager;
+
+  private HelixGatewayParticipant(HelixGatewayServiceChannel 
gatewayServiceChannel, Runnable onDisconnectedCallback,
+      HelixManager helixManager, Map<String, Map<String, String>> 
initialShardStateMap,
+      GatewayServiceManager gatewayServiceManager) {
     _gatewayServiceChannel = gatewayServiceChannel;
     _helixManager = helixManager;
     _onDisconnectedCallback = onDisconnectedCallback;
-    _shardStateMap = initialShardStateMap;
     _stateTransitionResultMap = new ConcurrentHashMap<>();
+    _gatewayServiceManager = gatewayServiceManager;
   }
 
   public void processStateTransitionMessage(Message message) throws Exception {
@@ -69,12 +68,14 @@ public class HelixGatewayParticipant implements 
HelixManagerStateListener {
     String concatenatedShardName = resourceId + shardId;
 
     try {
+      // update the target state in cache
+      _gatewayServiceManager.updateTargetState(_helixManager.getClusterName(), 
_helixManager.getInstanceName(),
+          resourceId, shardId, toState);
+
       if (isCurrentStateAlreadyTarget(resourceId, shardId, toState)) {
         return;
       }
-
       CompletableFuture<String> future = new CompletableFuture<>();
-
       _stateTransitionResultMap.put(concatenatedShardName, future);
       
_gatewayServiceChannel.sendStateChangeRequests(_helixManager.getInstanceName(),
           
StateTransitionMessageTranslateUtil.translateSTMsgToShardChangeRequests(message));
@@ -82,8 +83,6 @@ public class HelixGatewayParticipant implements 
HelixManagerStateListener {
       if (!toState.equals(future.get())) {
         throw new Exception("Failed to transition to state " + toState);
       }
-
-      updateState(resourceId, shardId, toState);
     } finally {
       _stateTransitionResultMap.remove(concatenatedShardName);
     }
@@ -92,24 +91,18 @@ public class HelixGatewayParticipant implements 
HelixManagerStateListener {
   public void handleStateTransitionError(Message message, StateTransitionError 
error) {
     // Remove the stateTransitionResultMap future for the message
     String transitionId = message.getMsgId();
-    String resourceId = message.getResourceName();
-    String shardId = message.getPartitionName();
 
     // Remove the future from the stateTransitionResultMap since we are no 
longer able
     // to process the state transition due to participant manager either 
timing out
     // or failing to process the state transition
     _stateTransitionResultMap.remove(transitionId);
 
-    // Set the replica state to ERROR
-    updateState(resourceId, shardId, HelixDefinedState.ERROR.name());
-
     // Notify the HelixGatewayParticipantClient that it is in ERROR state
     // TODO: We need a better way than sending the state transition with a 
toState of ERROR
   }
 
   /**
    * Get the instance name of the participant.
-   *
    * @return participant instance name
    */
   public String getInstanceName() {
@@ -118,7 +111,6 @@ public class HelixGatewayParticipant implements 
HelixManagerStateListener {
 
   /**
    * Completes the state transition with the given transitionId.
-   *
    */
   public void completeStateTransition(String resourceId, String shardId, 
String currentState) {
     String concatenatedShardName = resourceId + shardId;
@@ -128,16 +120,10 @@ public class HelixGatewayParticipant implements 
HelixManagerStateListener {
     }
   }
 
-  private boolean isCurrentStateAlreadyTarget(String resourceId, String 
shardId,
-      String targetState) {
+  private boolean isCurrentStateAlreadyTarget(String resourceId, String 
shardId, String targetState) {
     return getCurrentState(resourceId, shardId).equals(targetState);
   }
 
-  @VisibleForTesting
-  Map<String, Map<String, String>> getShardStateMap() {
-    return _shardStateMap;
-  }
-
   /**
    * Get the current state of the shard.
    *
@@ -146,23 +132,10 @@ public class HelixGatewayParticipant implements 
HelixManagerStateListener {
    * @return the current state of the shard or DROPPED if it does not exist
    */
   public String getCurrentState(String resourceId, String shardId) {
-    return getShardStateMap().getOrDefault(resourceId, Collections.emptyMap())
-        .getOrDefault(shardId, UNASSIGNED_STATE);
-  }
-
-  private void updateState(String resourceId, String shardId, String state) {
-    if (state.equals(HelixDefinedState.DROPPED.name())) {
-      getShardStateMap().computeIfPresent(resourceId, (k, v) -> {
-        v.remove(shardId);
-        if (v.isEmpty()) {
-          return null;
-        }
-        return v;
-      });
-    } else {
-      getShardStateMap().computeIfAbsent(resourceId, k -> new 
ConcurrentHashMap<>())
-          .put(shardId, state);
-    }
+    String currentState =
+        _gatewayServiceManager.getCurrentState(_helixManager.getClusterName(), 
_helixManager.getInstanceName(),
+            resourceId, shardId);
+    return currentState == null ? UNASSIGNED_STATE : currentState;
   }
 
   /**
@@ -203,13 +176,15 @@ public class HelixGatewayParticipant implements 
HelixManagerStateListener {
     private final Runnable _onDisconnectedCallback;
     private final List<String> _multiTopStateModelDefinitions;
     private final Map<String, Map<String, String>> _initialShardStateMap;
+    private final GatewayServiceManager _gatewayServiceManager;
 
-    public Builder(HelixGatewayServiceChannel helixGatewayServiceChannel, 
String instanceName,
-        String clusterName, String zkAddress, Runnable onDisconnectedCallback) 
{
+    public Builder(HelixGatewayServiceChannel helixGatewayServiceChannel, 
String instanceName, String clusterName,
+        String zkAddress, Runnable onDisconnectedCallback, 
GatewayServiceManager gatewayServiceManager) {
       _helixGatewayServiceChannel = helixGatewayServiceChannel;
       _instanceName = instanceName;
       _clusterName = clusterName;
       _zkAddress = zkAddress;
+      _gatewayServiceManager = gatewayServiceManager;
       _onDisconnectedCallback = onDisconnectedCallback;
       _multiTopStateModelDefinitions = new ArrayList<>();
       _initialShardStateMap = new ConcurrentHashMap<>();
@@ -259,13 +234,11 @@ public class HelixGatewayParticipant implements 
HelixManagerStateListener {
       HelixManager participantManager =
           new ZKHelixManager(_clusterName, _instanceName, 
InstanceType.PARTICIPANT, _zkAddress);
       HelixGatewayParticipant participant =
-          new HelixGatewayParticipant(_helixGatewayServiceChannel, 
_onDisconnectedCallback,
-              participantManager,
-              _initialShardStateMap);
-      _multiTopStateModelDefinitions.forEach(
-          stateModelDefinition -> participantManager.getStateMachineEngine()
-              .registerStateModelFactory(stateModelDefinition,
-                  new 
HelixGatewayMultiTopStateStateModelFactory(participant)));
+          new HelixGatewayParticipant(_helixGatewayServiceChannel, 
_onDisconnectedCallback, participantManager,
+              _initialShardStateMap, _gatewayServiceManager);
+      _multiTopStateModelDefinitions.forEach(stateModelDefinition -> 
participantManager.getStateMachineEngine()
+          .registerStateModelFactory(stateModelDefinition,
+              new HelixGatewayMultiTopStateStateModelFactory(participant)));
       try {
         participantManager.connect();
       } catch (Exception e) {
diff --git 
a/helix-gateway/src/main/java/org/apache/helix/gateway/service/GatewayServiceManager.java
 
b/helix-gateway/src/main/java/org/apache/helix/gateway/service/GatewayServiceManager.java
index 94f5783f0..e4d207fd7 100644
--- 
a/helix-gateway/src/main/java/org/apache/helix/gateway/service/GatewayServiceManager.java
+++ 
b/helix-gateway/src/main/java/org/apache/helix/gateway/service/GatewayServiceManager.java
@@ -108,29 +108,48 @@ public class GatewayServiceManager {
     }
   }
 
-  private GatewayCurrentStateCache getCache(String clusterName) {
-    return _currentStateCacheMap.computeIfAbsent(clusterName, k -> new 
GatewayCurrentStateCache(clusterName));
-  }
-
   public void resetTargetStateCache(String clusterName, String instanceName) {
-    getCache(clusterName).resetTargetStateCache(instanceName);
+    getOrCreateCache(clusterName).resetTargetStateCache(instanceName);
   }
 
+  /**
+   * Overwrite the current state cache with the new current state map, and 
return the diff of the change.
+   * @param clusterName
+   * @param newCurrentStateMap
+   * @return
+   */
   public  Map<String, Map<String, Map<String, String>>> 
updateCacheWithNewCurrentStateAndGetDiff(String clusterName,
       Map<String, Map<String, Map<String, String>>> newCurrentStateMap) {
-   return  
getCache(clusterName).updateCacheWithNewCurrentStateAndGetDiff(newCurrentStateMap);
+   return  
getOrCreateCache(clusterName).updateCacheWithNewCurrentStateAndGetDiff(newCurrentStateMap);
+  }
+
+  public void updateCurrentState(String clusterName, String instanceName, 
String resourceId, String shardId, String toState) {
+    
getOrCreateCache(clusterName).updateCurrentStateOfExistingInstance(instanceName,
 resourceId, shardId, toState);
   }
 
-  public String serializeTargetState() {
+  public synchronized String serializeTargetState() {
     ObjectNode targetStateNode = new ObjectMapper().createObjectNode();
     for (String clusterName : _currentStateCacheMap.keySet()) {
       // add the json node to the target state node
-      targetStateNode.set(clusterName, 
getCache(clusterName).serializeTargetAssignmentsToJSONNode());
+      targetStateNode.set(clusterName, 
getOrCreateCache(clusterName).serializeTargetAssignmentsToJSONNode());
     }
     targetStateNode.set("timestamp", 
objectMapper.valueToTree(System.currentTimeMillis()));
     return targetStateNode.toString();
   }
 
+  public void updateTargetState(String clusterName, String instanceName, 
String resourceId, String shardId,
+      String toState) {
+    
getOrCreateCache(clusterName).updateTargetStateOfExistingInstance(instanceName, 
resourceId, shardId, toState);
+  }
+
+  public String getCurrentState(String clusterName, String instanceName, 
String resourceId, String shardId) {
+    return getOrCreateCache(clusterName).getCurrentState(instanceName, 
resourceId, shardId);
+  }
+
+  public String getTargetState(String clusterName, String instanceName, String 
resourceId, String shardId) {
+    return getOrCreateCache(clusterName).getTargetState(instanceName, 
resourceId, shardId);
+  }
+
   /**
    * Update in memory shard state
    */
@@ -199,12 +218,10 @@ public class GatewayServiceManager {
     resetTargetStateCache(clusterName, instanceName);
     // Create and add the participant to the participant map
     HelixGatewayParticipant.Builder participantBuilder =
-        new HelixGatewayParticipant.Builder(_gatewayServiceChannel, 
instanceName, clusterName,
-            _zkAddress,
-            () -> removeHelixGatewayParticipant(clusterName, 
instanceName)).setInitialShardState(
+        new HelixGatewayParticipant.Builder(_gatewayServiceChannel, 
instanceName, clusterName, _zkAddress,
+            () -> removeHelixGatewayParticipant(clusterName, instanceName), 
this).setInitialShardState(
             initialShardStateMap);
-    SUPPORTED_MULTI_STATE_MODEL_TYPES.forEach(
-        participantBuilder::addMultiTopStateStateModelDefinition);
+    
SUPPORTED_MULTI_STATE_MODEL_TYPES.forEach(participantBuilder::addMultiTopStateStateModelDefinition);
     _helixGatewayParticipantMap.computeIfAbsent(clusterName, k -> new 
ConcurrentHashMap<>())
         .put(instanceName, participantBuilder.build());
   }
@@ -216,6 +233,7 @@ public class GatewayServiceManager {
       participant.disconnect();
       _helixGatewayParticipantMap.get(clusterName).remove(instanceName);
     }
+    
_currentStateCacheMap.get(clusterName).removeInstanceTargetDataFromCache(instanceName);
   }
 
   private HelixGatewayParticipant getHelixGatewayParticipant(String 
clusterName,
@@ -223,4 +241,8 @@ public class GatewayServiceManager {
     return _helixGatewayParticipantMap.getOrDefault(clusterName, 
Collections.emptyMap())
         .get(instanceName);
   }
+
+  private synchronized GatewayCurrentStateCache getOrCreateCache(String 
clusterName) {
+    return _currentStateCacheMap.computeIfAbsent(clusterName, k -> new 
GatewayCurrentStateCache(clusterName));
+  }
 }
diff --git 
a/helix-gateway/src/main/java/org/apache/helix/gateway/util/GatewayCurrentStateCache.java
 
b/helix-gateway/src/main/java/org/apache/helix/gateway/util/GatewayCurrentStateCache.java
index a3ba7fbe7..2b8b1c978 100644
--- 
a/helix-gateway/src/main/java/org/apache/helix/gateway/util/GatewayCurrentStateCache.java
+++ 
b/helix-gateway/src/main/java/org/apache/helix/gateway/util/GatewayCurrentStateCache.java
@@ -23,14 +23,17 @@ import com.fasterxml.jackson.databind.ObjectMapper;
 import com.fasterxml.jackson.databind.node.ObjectNode;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 
 /**
  * A cache to store the current target assignment, and the reported current 
state of the instances in a cluster.
  */
 public class GatewayCurrentStateCache {
-  static ObjectMapper mapper = new ObjectMapper();
-  String _clusterName;
+  private static final Logger logger = 
LoggerFactory.getLogger(GatewayCurrentStateCache.class);
+  private static final ObjectMapper mapper = new ObjectMapper();
 
   // A cache of current state. It should be updated by the 
HelixGatewayServiceChannel
   // instance -> resource state (resource -> shard -> target state)
@@ -38,120 +41,123 @@ public class GatewayCurrentStateCache {
 
   // A cache of target state.
   // instance -> resource state (resource -> shard -> target state)
-  Map<String, ShardStateMap> _targetStateMap;
+  final Map<String, ShardStateMap> _targetStateMap;
+  ObjectNode root = mapper.createObjectNode();
 
   public GatewayCurrentStateCache(String clusterName) {
-    _clusterName = clusterName;
     _currentStateMap = new HashMap<>();
     _targetStateMap = new HashMap<>();
   }
 
   public String getCurrentState(String instance, String resource, String 
shard) {
-    return _currentStateMap.get(instance).getState(resource, shard);
+    ShardStateMap shardStateMap = _currentStateMap.get(instance);
+    return shardStateMap == null ? null : shardStateMap.getState(resource, 
shard);
   }
 
   public String getTargetState(String instance, String resource, String shard) 
{
-    return _targetStateMap.get(instance).getState(resource, shard);
+    ShardStateMap shardStateMap = _targetStateMap.get(instance);
+    return shardStateMap == null ? null : shardStateMap.getState(resource, 
shard);
   }
 
   /**
    * Update the cached current state of instances in a cluster, and return the 
diff of the change.
-   * @param newCurrentStateMap The new current state map of instances in the 
cluster
+   * @param userCurrentStateMap The new current state map of instances in the 
cluster
    * @return
    */
   public Map<String, Map<String, Map<String, String>>> 
updateCacheWithNewCurrentStateAndGetDiff(
-      Map<String, Map<String, Map<String, String>>> newCurrentStateMap) {
+      Map<String, Map<String, Map<String, String>>> userCurrentStateMap) {
+    Map<String, ShardStateMap> newCurrentStateMap = new 
HashMap<>(_currentStateMap);
     Map<String, Map<String, Map<String, String>>> diff = new HashMap<>();
-    for (String instance : newCurrentStateMap.keySet()) {
-      Map<String, Map<String, String>> newCurrentState = 
newCurrentStateMap.get(instance);
-      Map<String, Map<String, String>> resourceStateDiff =
-          _currentStateMap.computeIfAbsent(instance, k -> new 
ShardStateMap(new HashMap<>()))
-              .updateAndGetDiff(newCurrentState);
-      if (resourceStateDiff != null && !resourceStateDiff.isEmpty()) {
-        diff.put(instance, resourceStateDiff);
+    for (String instance : userCurrentStateMap.keySet()) {
+      ShardStateMap oldStateMap = _currentStateMap.get(instance);
+      Map<String, Map<String, String>> instanceDiff = oldStateMap == null ? 
userCurrentStateMap.get(instance)
+          : oldStateMap.getDiff(userCurrentStateMap.get(instance));
+      if (!instanceDiff.isEmpty()) {
+        diff.put(instance, instanceDiff);
       }
+      newCurrentStateMap.put(instance, new 
ShardStateMap(userCurrentStateMap.get(instance)));
     }
+    logger.info("Update current state cache for instances: {}", diff.keySet());
+    _currentStateMap = newCurrentStateMap;
     return diff;
   }
 
   /**
-   * Update the cache with the current state diff.
-   * All existing target states remains the same
-   * @param currentStateDiff
+   * Update the current state with the changed current state maps.
    */
-  public void updateCacheWithCurrentStateDiff(Map<String, Map<String, 
Map<String, String>>> currentStateDiff) {
-    for (String instance : currentStateDiff.keySet()) {
-      Map<String, Map<String, String>> currentStateDiffMap = 
currentStateDiff.get(instance);
-      updateShardStateMapWithDiff(_currentStateMap, instance, 
currentStateDiffMap);
-    }
+  public void updateCurrentStateOfExistingInstance(String instance, String 
resource, String shard, String shardState) {
+    logger.info("Update current state of instance: {}, resource: {}, shard: 
{}, state: {}", instance, resource, shard,
+        shardState);
+    updateShardStateMapWithDiff(_currentStateMap, instance, resource, shard, 
shardState);
   }
 
   /**
    * Update the target state with the changed target state maps.
    * All existing target states remains the same
-   * @param targetStateChangeMap
    */
-  public void updateTargetStateWithDiff(String instance, Map<String, 
Map<String, String>> targetStateChangeMap) {
-    updateShardStateMapWithDiff(_targetStateMap, instance, 
targetStateChangeMap);
+  public void updateTargetStateOfExistingInstance(String instance, String 
resource, String shard, String shardState) {
+    logger.info("Update target state of instance: {}, resource: {}, shard: {}, 
state: {}", instance, resource, shard,
+        shardState);
+    updateShardStateMapWithDiff(_targetStateMap, instance, resource, shard, 
shardState);
+  }
+
+  private void updateShardStateMapWithDiff(Map<String, ShardStateMap> 
stateMap, String instance, String resource,
+      String shard, String shardState) {
+    ShardStateMap curStateMap = stateMap.get(instance);
+    if (curStateMap == null) {
+      logger.warn("Instance {} is not in the state map, skip updating", 
instance);
+      return;
+    }
+    curStateMap.updateWithShardState(resource, shard, shardState);
   }
 
   /**
    * Serialize the target state assignments to a JSON Node.
    * example : 
{"instance1":{"resource1":{"shard1":"ONLINE","shard2":"OFFLINE"}}}}
    */
-  public ObjectNode serializeTargetAssignmentsToJSONNode() {
-    ObjectNode root = mapper.createObjectNode();
+  public synchronized ObjectNode serializeTargetAssignmentsToJSONNode() {
     for (Map.Entry<String, ShardStateMap> entry : _targetStateMap.entrySet()) {
       root.set(entry.getKey(), entry.getValue().toJSONNode());
     }
     return root;
   }
 
-  private void updateShardStateMapWithDiff(Map<String, ShardStateMap> 
stateMap, String instance,
-      Map<String, Map<String, String>> diffMap) {
-    if (diffMap == null || diffMap.isEmpty()) {
-      return;
-    }
-    stateMap.computeIfAbsent(instance, k -> new ShardStateMap(new 
HashMap<>())).updateWithDiff(diffMap);
+  /**
+   * Remove the target state data of an instance from the cache.
+   */
+  public synchronized void removeInstanceTargetDataFromCache(String instance) {
+    logger.info("Remove instance target data from cache for instance: {}", 
instance);
+    _targetStateMap.remove(instance);
+    root.remove(instance);
   }
 
-  public void resetTargetStateCache(String instance) {
+  /**
+   * Remove the current state data of an instance from the cache to an empty 
map.
+   */
+  public synchronized void resetTargetStateCache(String instance) {
+    logger.info("Reset target state cache for instance: {}", instance);
     _targetStateMap.put(instance, new ShardStateMap(new HashMap<>()));
   }
 
   public static class ShardStateMap {
     Map<String, Map<String, String>> _stateMap;
+    ObjectNode root = mapper.createObjectNode();
 
     public ShardStateMap(Map<String, Map<String, String>> stateMap) {
-      _stateMap = stateMap;
+      _stateMap = new HashMap<>(stateMap);
     }
 
-    public String getState(String instance, String shard) {
-      return _stateMap.get(instance).get(shard);
+    public String getState(String resource, String shard) {
+      Map<String, String> shardStateMap = _stateMap.get(resource);
+      return shardStateMap == null ? null : shardStateMap.get(shard);
     }
 
-    private Map<String, Map<String, String>> getShardStateMap() {
-      return _stateMap;
-    }
-
-    private void updateWithDiff(Map<String, Map<String, String>> diffMap) {
-      for (Map.Entry<String, Map<String, String>> diffEntry : 
diffMap.entrySet()) {
-        String resource = diffEntry.getKey();
-        Map<String, String> diffCurrentState = diffEntry.getValue();
-        if (_stateMap.get(resource) != null) {
-          _stateMap.get(resource).entrySet().forEach(currentMapEntry -> {
-            String shard = currentMapEntry.getKey();
-            if (diffCurrentState.get(shard) != null) {
-              currentMapEntry.setValue(diffCurrentState.get(shard));
-            }
-          });
-        } else {
-          _stateMap.put(resource, diffCurrentState);
-        }
-      }
+    public synchronized void updateWithShardState(String resource, String 
shard, String shardState) {
+      logger.info("Update ShardStateMap of resource: {}, shard: {}, state: 
{}", resource, shard, shardState);
+      _stateMap.computeIfAbsent(resource, k -> new HashMap<>()).put(shard, 
shardState);
     }
 
-    private Map<String, Map<String, String>> updateAndGetDiff(Map<String, 
Map<String, String>> newCurrentStateMap) {
+    private Map<String, Map<String, String>> getDiff(Map<String, Map<String, 
String>> newCurrentStateMap) {
       Map<String, Map<String, String>> diff = new HashMap<>();
       for (Map.Entry<String, Map<String, String>> entry : 
newCurrentStateMap.entrySet()) {
         String resource = entry.getKey();
@@ -169,7 +175,6 @@ public class GatewayCurrentStateCache {
           }
         }
       }
-      _stateMap = newCurrentStateMap;
       return diff;
     }
 
@@ -177,8 +182,7 @@ public class GatewayCurrentStateCache {
      * Serialize the shard state map to a JSON object.
      * @return a JSON object representing the shard state map. Example: 
{"shard1":"ONLINE","shard2":"OFFLINE"}
      */
-    public ObjectNode toJSONNode() {
-      ObjectNode root = mapper.createObjectNode();
+    public synchronized ObjectNode toJSONNode() {
       for (Map.Entry<String, Map<String, String>> entry : 
_stateMap.entrySet()) {
         String resource = entry.getKey();
         ObjectNode resourceNode = mapper.createObjectNode();
diff --git 
a/helix-gateway/src/main/java/org/apache/helix/gateway/util/StateTransitionMessageTranslateUtil.java
 
b/helix-gateway/src/main/java/org/apache/helix/gateway/util/StateTransitionMessageTranslateUtil.java
index f9f2b4c3d..a2d07085b 100644
--- 
a/helix-gateway/src/main/java/org/apache/helix/gateway/util/StateTransitionMessageTranslateUtil.java
+++ 
b/helix-gateway/src/main/java/org/apache/helix/gateway/util/StateTransitionMessageTranslateUtil.java
@@ -27,6 +27,7 @@ import org.apache.helix.HelixDefinedState;
 import org.apache.helix.gateway.api.constant.GatewayServiceEventType;
 import org.apache.helix.gateway.participant.HelixGatewayParticipant;
 import org.apache.helix.gateway.service.GatewayServiceEvent;
+import org.apache.helix.gateway.service.GatewayServiceManager;
 import org.apache.helix.model.Message;
 import proto.org.apache.helix.gateway.HelixGatewayServiceOuterClass;
 import 
proto.org.apache.helix.gateway.HelixGatewayServiceOuterClass.ShardChangeRequests;
@@ -79,7 +80,8 @@ public final class StateTransitionMessageTranslateUtil {
    *                contains the state of each shard upon connection or result 
of state transition request.
    * @return GatewayServiceEvent
    */
-  public static GatewayServiceEvent 
translateShardStateMessageToEvent(ShardStateMessage request) {
+  public static GatewayServiceEvent 
translateShardStateMessageToEventAndUpdateCache(
+      GatewayServiceManager manager, ShardStateMessage request) {
     GatewayServiceEvent.GateWayServiceEventBuilder builder;
     if (request.hasShardState()) { // init connection to gateway service
       ShardState shardState = request.getShardState();
@@ -90,6 +92,11 @@ public final class StateTransitionMessageTranslateUtil {
               .put(state.getShardName(), state.getCurrentState());
         }
       }
+      // update current state cache. We always overwrite the current state map 
for initial connection
+      Map<String, Map<String, Map<String, String>>> newShardStateMap = new 
HashMap<>();
+      newShardStateMap.put(shardState.getInstanceName(), shardStateMap);
+      
manager.updateCacheWithNewCurrentStateAndGetDiff(shardState.getClusterName(), 
newShardStateMap);
+
       builder = new 
GatewayServiceEvent.GateWayServiceEventBuilder(GatewayServiceEventType.CONNECT).setClusterName(
               
shardState.getClusterName()).setParticipantName(shardState.getInstanceName())
           .setShardStateMap(shardStateMap);
@@ -103,6 +110,9 @@ public final class StateTransitionMessageTranslateUtil {
         GatewayServiceEvent.StateTransitionResult result =
             new 
GatewayServiceEvent.StateTransitionResult(shardTransition.getResourceName(),
                 shardTransition.getShardName(), 
shardTransition.getCurrentState());
+        // update current state cache
+        manager.updateCurrentState(shardTransitionStatus.getClusterName(), 
shardTransitionStatus.getInstanceName(),
+            shardTransition.getResourceName(), shardTransition.getShardName(), 
shardTransition.getCurrentState());
         stResult.add(result);
       }
       builder = new 
GatewayServiceEvent.GateWayServiceEventBuilder(GatewayServiceEventType.UPDATE).setClusterName(
diff --git 
a/helix-gateway/src/test/java/org/apache/helix/gateway/participant/TestHelixGatewayParticipant.java
 
b/helix-gateway/src/test/java/org/apache/helix/gateway/participant/TestHelixGatewayParticipant.java
index 75e9fc581..128a22857 100644
--- 
a/helix-gateway/src/test/java/org/apache/helix/gateway/participant/TestHelixGatewayParticipant.java
+++ 
b/helix-gateway/src/test/java/org/apache/helix/gateway/participant/TestHelixGatewayParticipant.java
@@ -21,6 +21,7 @@ package org.apache.helix.gateway.participant;
 
 import java.io.IOException;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
@@ -32,6 +33,8 @@ import org.apache.helix.ConfigAccessor;
 import org.apache.helix.TestHelper;
 import org.apache.helix.common.ZkTestBase;
 import org.apache.helix.gateway.api.service.HelixGatewayServiceChannel;
+import org.apache.helix.gateway.channel.GatewayServiceChannelConfig;
+import org.apache.helix.gateway.service.GatewayServiceManager;
 import org.apache.helix.integration.manager.ClusterControllerManager;
 import org.apache.helix.manager.zk.ZKHelixManager;
 import org.apache.helix.model.ClusterConfig;
@@ -62,8 +65,17 @@ public class TestHelixGatewayParticipant extends ZkTestBase {
   private final Map<String, ShardChangeRequests> _pendingMessageMap = new 
ConcurrentHashMap<>();
   private final AtomicInteger _onDisconnectCallbackCount = new AtomicInteger();
 
+  private GatewayServiceManager _gatewayServiceManager;
+
   @BeforeClass
   public void beforeClass() {
+    GatewayServiceChannelConfig.GatewayServiceProcessorConfigBuilder builder =
+        new GatewayServiceChannelConfig.GatewayServiceProcessorConfigBuilder();
+
+    
builder.setParticipantConnectionChannelType(GatewayServiceChannelConfig.ChannelType.GRPC_SERVER).setGrpcServerPort(5001);
+    GatewayServiceChannelConfig config = builder.build();
+    _gatewayServiceManager = new GatewayServiceManager(ZK_ADDR, config);
+
     // Set up the Helix cluster
     ConfigAccessor configAccessor = new ConfigAccessor(_gZkClient);
     _gSetupTool.addCluster(CLUSTER_NAME, true);
@@ -100,9 +112,10 @@ public class TestHelixGatewayParticipant extends 
ZkTestBase {
    */
   private HelixGatewayParticipant addParticipant(String participantName,
       Map<String, Map<String, String>> initialShardMap) {
+    _gatewayServiceManager.resetTargetStateCache(CLUSTER_NAME, 
participantName);
     HelixGatewayParticipant participant =
         new HelixGatewayParticipant.Builder(new 
MockHelixGatewayServiceChannel(_pendingMessageMap), participantName,
-            CLUSTER_NAME, ZK_ADDR, 
_onDisconnectCallbackCount::incrementAndGet).addMultiTopStateStateModelDefinition(
+            CLUSTER_NAME, ZK_ADDR, 
_onDisconnectCallbackCount::incrementAndGet, 
_gatewayServiceManager).addMultiTopStateStateModelDefinition(
             TEST_STATE_MODEL).setInitialShardState(initialShardMap).build();
     _participants.add(participant);
     return participant;
@@ -175,8 +188,26 @@ public class TestHelixGatewayParticipant extends 
ZkTestBase {
   private void processPendingMessage(HelixGatewayParticipant participant, 
boolean isSuccess, String toState) {
     ShardChangeRequests requests = 
_pendingMessageMap.remove(participant.getInstanceName());
 
-    
participant.completeStateTransition(requests.getRequest(0).getResourceName(),requests.getRequest(0).getShardName(),
-        isSuccess ? toState : "WRONG_STATE");
+    Map<String, Map<String, Map<String, String>>> newSInstanceStateMap = new 
HashMap<>();
+    newSInstanceStateMap.put(participant.getInstanceName(),
+        createSingleShardStateMap(requests.getRequest(0).getResourceName(), 
requests.getRequest(0).getShardName(),
+            isSuccess ? toState : "ERROR"));
+    
_gatewayServiceManager.updateCacheWithNewCurrentStateAndGetDiff(CLUSTER_NAME, 
newSInstanceStateMap);
+
+    
participant.completeStateTransition(requests.getRequest(0).getResourceName(), 
requests.getRequest(0).getShardName(),
+        isSuccess ? toState : "ERROR");
+  }
+
+  /**
+   * Create a single shard state map.
+   */
+   Map<String, Map<String, String>> createSingleShardStateMap( String 
resource, String shard, String state) {
+
+    Map<String, Map<String, String>> resourceStateMap = new HashMap<>();
+    Map<String, String> shardStateMap = new HashMap<>();
+    shardStateMap.put(shard, state);
+    resourceStateMap.put(resource, shardStateMap);
+    return resourceStateMap;
   }
 
   /**
@@ -209,7 +240,9 @@ public class TestHelixGatewayParticipant extends ZkTestBase 
{
             .getResourceIdealState(CLUSTER_NAME, resourceName)
             .getPartitionSet()) {
           String helixCurrentState = getHelixCurrentState(instanceName, 
resourceName, shardId);
-          if (!participant.getCurrentState(resourceName, 
shardId).equals(helixCurrentState)) {
+          if (!participant.getCurrentState(resourceName, 
shardId).equals(helixCurrentState) && !(
+              participant.getCurrentState(resourceName, 
shardId).equals("DROPPED") && helixCurrentState.equals(
+                  "UNASSIGNED"))) {
             return false;
           }
         }
@@ -218,6 +251,28 @@ public class TestHelixGatewayParticipant extends 
ZkTestBase {
     }), TestHelper.WAIT_DURATION));
   }
 
+  private void verifyGatewayTargetStateMatchHelixTargetState() throws 
Exception {
+    Assert.assertTrue(TestHelper.verify(() -> 
_participants.stream().allMatch(participant -> {
+      String instanceName = participant.getInstanceName();
+      for (String resourceName : 
_gSetupTool.getClusterManagementTool().getResourcesInCluster(CLUSTER_NAME)) {
+        for (String shardId : _gSetupTool.getClusterManagementTool()
+            .getResourceIdealState(CLUSTER_NAME, resourceName)
+            .getPartitionSet()) {
+          String helixTargetState = getHelixCurrentState(instanceName, 
resourceName, shardId);
+          if (_gatewayServiceManager.getTargetState(CLUSTER_NAME, 
instanceName, resourceName, shardId) == null) {
+            System.out.println("Gateway target state is null for instance: " + 
instanceName + ", resource: " + resourceName + ", shard: " + shardId);
+          }
+          if (!participant.getCurrentState(resourceName, 
shardId).equals(helixTargetState) && !(
+              participant.getCurrentState(resourceName, 
shardId).equals("DROPPED") && helixTargetState.equals(
+                  "UNASSIGNED"))) {
+            return false;
+          }
+        }
+      }
+      return true;
+    }), 6000L));
+  }
+
   /**
    * Verify that all shards for a given instance are in a specific state.
    */
@@ -256,6 +311,7 @@ public class TestHelixGatewayParticipant extends ZkTestBase 
{
     // Verify that the cluster converges and all states are "ONLINE"
     Assert.assertTrue(_clusterVerifier.verify());
     verifyGatewayStateMatchesHelixState();
+    verifyGatewayTargetStateMatchHelixTargetState();
   }
 
   @Test(dependsOnMethods = "testProcessStateTransitionMessageSuccess")
@@ -293,7 +349,9 @@ public class TestHelixGatewayParticipant extends ZkTestBase 
{
     verifyHelixPartitionStates(participant.getInstanceName(), 
HelixGatewayParticipant.UNASSIGNED_STATE);
 
     // Re-add the participant with its initial state
-    addParticipant(participant.getInstanceName(), 
participant.getShardStateMap());
+    addParticipant(participant.getInstanceName(),  
createSingleShardStateMap(TEST_DB, "TestDB_0",
+        _gatewayServiceManager.getCurrentState(CLUSTER_NAME, 
participant.getInstanceName(), TEST_DB,
+            "TestDB_0")));
     Assert.assertTrue(_clusterVerifier.verify());
 
     // Verify the Helix state is "ONLINE"
@@ -309,8 +367,10 @@ public class TestHelixGatewayParticipant extends 
ZkTestBase {
 
     // Remove shard preference and re-add the participant
     removeFromPreferenceList(participant);
-    HelixGatewayParticipant participantReplacement =
-        addParticipant(participant.getInstanceName(), 
participant.getShardStateMap());
+    HelixGatewayParticipant participantReplacement = 
addParticipant(participant.getInstanceName(),
+        createSingleShardStateMap(TEST_DB, "TestDB_0",
+            _gatewayServiceManager.getCurrentState(CLUSTER_NAME, 
participant.getInstanceName(), TEST_DB,
+                "TestDB_0")));
     verifyPendingMessages(List.of(participantReplacement));
 
     // Process the pending message successfully
@@ -319,6 +379,7 @@ public class TestHelixGatewayParticipant extends ZkTestBase 
{
     // Verify that the cluster converges and states are correctly updated to 
"ONLINE"
     Assert.assertTrue(_clusterVerifier.verify());
     verifyGatewayStateMatchesHelixState();
+    verifyGatewayTargetStateMatchHelixTargetState();
   }
 
   @Test(dependsOnMethods = 
"testProcessStateTransitionAfterReconnectAfterDroppingPartition")
diff --git 
a/helix-gateway/src/test/java/org/apache/helix/gateway/util/TestGatewayCurrentStateCache.java
 
b/helix-gateway/src/test/java/org/apache/helix/gateway/util/TestGatewayCurrentStateCache.java
index 99fec4a25..17aa25637 100644
--- 
a/helix-gateway/src/test/java/org/apache/helix/gateway/util/TestGatewayCurrentStateCache.java
+++ 
b/helix-gateway/src/test/java/org/apache/helix/gateway/util/TestGatewayCurrentStateCache.java
@@ -51,31 +51,28 @@ public class TestGatewayCurrentStateCache {
   }
 
   @Test
-  public void testUpdateCacheWithCurrentStateDiff() {
-    Map<String, Map<String, Map<String, String>>> diff = new HashMap<>();
+  public void testUpdateCacheWithExistingStateAndGetDiff() {
+    // Initial state
+    Map<String, Map<String, Map<String, String>>> initialState = new 
HashMap<>();
     Map<String, Map<String, String>> instanceState = new HashMap<>();
     Map<String, String> shardState = new HashMap<>();
-    shardState.put("shard2", "ONLINE");
     shardState.put("shard1", "ONLINE");
     instanceState.put("resource1", shardState);
-    diff.put("instance1", instanceState);
+    initialState.put("instance1", instanceState);
+    cache.updateCacheWithNewCurrentStateAndGetDiff(initialState);
 
-    cache.updateCacheWithCurrentStateDiff(diff);
-
-    Assert.assertEquals(cache.getCurrentState("instance1", "resource1", 
"shard1"), "ONLINE");
-    Assert.assertEquals(cache.getCurrentState("instance1", "resource1", 
"shard2"), "ONLINE");
-  }
-
-  @Test
-  public void testUpdateTargetStateWithDiff() {
-    Map<String, Map<String, String>> targetStateChange = new HashMap<>();
-    Map<String, String> shardState = new HashMap<>();
-    shardState.put("shard1", "OFFLINE");
-    targetStateChange.put("resource1", shardState);
+    // New state with a change
+    Map<String, Map<String, Map<String, String>>> newState = new HashMap<>();
+    Map<String, Map<String, String>> newInstanceState = new HashMap<>();
+    Map<String, String> newShardState = new HashMap<>();
+    newShardState.put("shard1", "OFFLINE");
+    newInstanceState.put("resource1", newShardState);
+    newState.put("instance1", newInstanceState);
 
-    cache.updateTargetStateWithDiff("instance1", targetStateChange);
+    Map<String, Map<String, Map<String, String>>> diff = 
cache.updateCacheWithNewCurrentStateAndGetDiff(newState);
 
-    Assert.assertEquals(cache.getTargetState("instance1", "resource1", 
"shard1"), "OFFLINE");
-    
Assert.assertEquals(cache.serializeTargetAssignmentsToJSONNode().toString(), 
"{\"instance1\":{\"resource1\":{\"shard1\":\"OFFLINE\"}}}");
+    Assert.assertNotNull(diff);
+    Assert.assertEquals(diff.size(), 1);
+    Assert.assertEquals(diff.get("instance1").get("resource1").get("shard1"), 
"OFFLINE");
   }
 }


Reply via email to