This is an automated email from the ASF dual-hosted git repository. xyuanlu pushed a commit to branch helix-gateway-service in repository https://gitbox.apache.org/repos/asf/helix.git
commit fd092b51b985fd3c5d9b8af12cab132dc6a38eb4 Author: xyuanlu <xyua...@gmail.com> AuthorDate: Thu Sep 12 20:50:11 2024 -0700 Gateway - gateway participant update target state in cache (#2910) Gateway - gateway participant update target state in cache --- .../channel/HelixGatewayServiceGrpcService.java | 13 +-- .../participant/HelixGatewayParticipant.java | 79 +++++-------- .../gateway/service/GatewayServiceManager.java | 48 +++++--- .../gateway/util/GatewayCurrentStateCache.java | 126 +++++++++++---------- .../util/StateTransitionMessageTranslateUtil.java | 12 +- .../participant/TestHelixGatewayParticipant.java | 75 ++++++++++-- .../gateway/util/TestGatewayCurrentStateCache.java | 35 +++--- 7 files changed, 226 insertions(+), 162 deletions(-) diff --git a/helix-gateway/src/main/java/org/apache/helix/gateway/channel/HelixGatewayServiceGrpcService.java b/helix-gateway/src/main/java/org/apache/helix/gateway/channel/HelixGatewayServiceGrpcService.java index a9c1ca4a2..6a2c76b8a 100644 --- a/helix-gateway/src/main/java/org/apache/helix/gateway/channel/HelixGatewayServiceGrpcService.java +++ b/helix-gateway/src/main/java/org/apache/helix/gateway/channel/HelixGatewayServiceGrpcService.java @@ -38,7 +38,6 @@ import org.apache.helix.gateway.util.StateTransitionMessageTranslateUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import proto.org.apache.helix.gateway.HelixGatewayServiceGrpc; -import proto.org.apache.helix.gateway.HelixGatewayServiceOuterClass; import proto.org.apache.helix.gateway.HelixGatewayServiceOuterClass.ShardChangeRequests; import proto.org.apache.helix.gateway.HelixGatewayServiceOuterClass.ShardState; import proto.org.apache.helix.gateway.HelixGatewayServiceOuterClass.ShardStateMessage; @@ -92,7 +91,7 @@ public class HelixGatewayServiceGrpcService extends HelixGatewayServiceGrpc.Heli updateObserver(shardState.getInstanceName(), shardState.getClusterName(), responseObserver); } pushClientEventToGatewayManager(_manager, - StateTransitionMessageTranslateUtil.translateShardStateMessageToEvent(request)); + StateTransitionMessageTranslateUtil.translateShardStateMessageToEventAndUpdateCache(_manager, request)); } @Override @@ -120,9 +119,8 @@ public class HelixGatewayServiceGrpcService extends HelixGatewayServiceGrpc.Heli */ @Override public void sendStateChangeRequests(String instanceName, ShardChangeRequests requests) { - StreamObserver<HelixGatewayServiceOuterClass.ShardChangeRequests> observer; - observer = _observerMap.get(instanceName); - if (observer != null) { + StreamObserver<ShardChangeRequests> observer = _observerMap.get(instanceName); + if (observer!= null) { observer.onNext(requests); } else { logger.error("Instance {} is not connected to the gateway service", instanceName); @@ -151,8 +149,7 @@ public class HelixGatewayServiceGrpcService extends HelixGatewayServiceGrpc.Heli } private void closeConnectionHelper(String instanceName, String errorReason, boolean withError) { - StreamObserver<ShardChangeRequests> observer; - observer = _observerMap.get(instanceName); + StreamObserver<ShardChangeRequests> observer = _observerMap.get(instanceName); if (observer != null) { if (withError) { observer.onError(Status.UNAVAILABLE.withDescription(errorReason).asRuntimeException()); @@ -162,7 +159,7 @@ public class HelixGatewayServiceGrpcService extends HelixGatewayServiceGrpc.Heli } } - public void onClientClose(String clusterName, String instanceName) { + private void onClientClose(String clusterName, String instanceName) { if (instanceName == null || clusterName == null) { // TODO: log error; return; diff --git a/helix-gateway/src/main/java/org/apache/helix/gateway/participant/HelixGatewayParticipant.java b/helix-gateway/src/main/java/org/apache/helix/gateway/participant/HelixGatewayParticipant.java index b17d897e1..8dd04644b 100644 --- a/helix-gateway/src/main/java/org/apache/helix/gateway/participant/HelixGatewayParticipant.java +++ b/helix-gateway/src/main/java/org/apache/helix/gateway/participant/HelixGatewayParticipant.java @@ -19,17 +19,15 @@ package org.apache.helix.gateway.participant; * under the License. */ -import com.google.common.annotations.VisibleForTesting; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; -import org.apache.helix.HelixDefinedState; import org.apache.helix.HelixManager; import org.apache.helix.InstanceType; import org.apache.helix.gateway.api.service.HelixGatewayServiceChannel; +import org.apache.helix.gateway.service.GatewayServiceManager; import org.apache.helix.gateway.statemodel.HelixGatewayMultiTopStateStateModelFactory; import org.apache.helix.gateway.util.StateTransitionMessageTranslateUtil; import org.apache.helix.manager.zk.HelixManagerStateListener; @@ -37,6 +35,7 @@ import org.apache.helix.manager.zk.ZKHelixManager; import org.apache.helix.model.Message; import org.apache.helix.participant.statemachine.StateTransitionError; + /** * HelixGatewayParticipant encapsulates the Helix Participant Manager and handles tracking the state * of a remote participant connected to the Helix Gateway Service. It processes state transitions @@ -48,18 +47,18 @@ public class HelixGatewayParticipant implements HelixManagerStateListener { private final HelixGatewayServiceChannel _gatewayServiceChannel; private final HelixManager _helixManager; private final Runnable _onDisconnectedCallback; - private final Map<String, Map<String, String>> _shardStateMap; - private final Map<String, CompletableFuture<String>> _stateTransitionResultMap; - private HelixGatewayParticipant(HelixGatewayServiceChannel gatewayServiceChannel, - Runnable onDisconnectedCallback, HelixManager helixManager, - Map<String, Map<String, String>> initialShardStateMap) { + private final GatewayServiceManager _gatewayServiceManager; + + private HelixGatewayParticipant(HelixGatewayServiceChannel gatewayServiceChannel, Runnable onDisconnectedCallback, + HelixManager helixManager, Map<String, Map<String, String>> initialShardStateMap, + GatewayServiceManager gatewayServiceManager) { _gatewayServiceChannel = gatewayServiceChannel; _helixManager = helixManager; _onDisconnectedCallback = onDisconnectedCallback; - _shardStateMap = initialShardStateMap; _stateTransitionResultMap = new ConcurrentHashMap<>(); + _gatewayServiceManager = gatewayServiceManager; } public void processStateTransitionMessage(Message message) throws Exception { @@ -69,12 +68,14 @@ public class HelixGatewayParticipant implements HelixManagerStateListener { String concatenatedShardName = resourceId + shardId; try { + // update the target state in cache + _gatewayServiceManager.updateTargetState(_helixManager.getClusterName(), _helixManager.getInstanceName(), + resourceId, shardId, toState); + if (isCurrentStateAlreadyTarget(resourceId, shardId, toState)) { return; } - CompletableFuture<String> future = new CompletableFuture<>(); - _stateTransitionResultMap.put(concatenatedShardName, future); _gatewayServiceChannel.sendStateChangeRequests(_helixManager.getInstanceName(), StateTransitionMessageTranslateUtil.translateSTMsgToShardChangeRequests(message)); @@ -82,8 +83,6 @@ public class HelixGatewayParticipant implements HelixManagerStateListener { if (!toState.equals(future.get())) { throw new Exception("Failed to transition to state " + toState); } - - updateState(resourceId, shardId, toState); } finally { _stateTransitionResultMap.remove(concatenatedShardName); } @@ -92,24 +91,18 @@ public class HelixGatewayParticipant implements HelixManagerStateListener { public void handleStateTransitionError(Message message, StateTransitionError error) { // Remove the stateTransitionResultMap future for the message String transitionId = message.getMsgId(); - String resourceId = message.getResourceName(); - String shardId = message.getPartitionName(); // Remove the future from the stateTransitionResultMap since we are no longer able // to process the state transition due to participant manager either timing out // or failing to process the state transition _stateTransitionResultMap.remove(transitionId); - // Set the replica state to ERROR - updateState(resourceId, shardId, HelixDefinedState.ERROR.name()); - // Notify the HelixGatewayParticipantClient that it is in ERROR state // TODO: We need a better way than sending the state transition with a toState of ERROR } /** * Get the instance name of the participant. - * * @return participant instance name */ public String getInstanceName() { @@ -118,7 +111,6 @@ public class HelixGatewayParticipant implements HelixManagerStateListener { /** * Completes the state transition with the given transitionId. - * */ public void completeStateTransition(String resourceId, String shardId, String currentState) { String concatenatedShardName = resourceId + shardId; @@ -128,16 +120,10 @@ public class HelixGatewayParticipant implements HelixManagerStateListener { } } - private boolean isCurrentStateAlreadyTarget(String resourceId, String shardId, - String targetState) { + private boolean isCurrentStateAlreadyTarget(String resourceId, String shardId, String targetState) { return getCurrentState(resourceId, shardId).equals(targetState); } - @VisibleForTesting - Map<String, Map<String, String>> getShardStateMap() { - return _shardStateMap; - } - /** * Get the current state of the shard. * @@ -146,23 +132,10 @@ public class HelixGatewayParticipant implements HelixManagerStateListener { * @return the current state of the shard or DROPPED if it does not exist */ public String getCurrentState(String resourceId, String shardId) { - return getShardStateMap().getOrDefault(resourceId, Collections.emptyMap()) - .getOrDefault(shardId, UNASSIGNED_STATE); - } - - private void updateState(String resourceId, String shardId, String state) { - if (state.equals(HelixDefinedState.DROPPED.name())) { - getShardStateMap().computeIfPresent(resourceId, (k, v) -> { - v.remove(shardId); - if (v.isEmpty()) { - return null; - } - return v; - }); - } else { - getShardStateMap().computeIfAbsent(resourceId, k -> new ConcurrentHashMap<>()) - .put(shardId, state); - } + String currentState = + _gatewayServiceManager.getCurrentState(_helixManager.getClusterName(), _helixManager.getInstanceName(), + resourceId, shardId); + return currentState == null ? UNASSIGNED_STATE : currentState; } /** @@ -203,13 +176,15 @@ public class HelixGatewayParticipant implements HelixManagerStateListener { private final Runnable _onDisconnectedCallback; private final List<String> _multiTopStateModelDefinitions; private final Map<String, Map<String, String>> _initialShardStateMap; + private final GatewayServiceManager _gatewayServiceManager; - public Builder(HelixGatewayServiceChannel helixGatewayServiceChannel, String instanceName, - String clusterName, String zkAddress, Runnable onDisconnectedCallback) { + public Builder(HelixGatewayServiceChannel helixGatewayServiceChannel, String instanceName, String clusterName, + String zkAddress, Runnable onDisconnectedCallback, GatewayServiceManager gatewayServiceManager) { _helixGatewayServiceChannel = helixGatewayServiceChannel; _instanceName = instanceName; _clusterName = clusterName; _zkAddress = zkAddress; + _gatewayServiceManager = gatewayServiceManager; _onDisconnectedCallback = onDisconnectedCallback; _multiTopStateModelDefinitions = new ArrayList<>(); _initialShardStateMap = new ConcurrentHashMap<>(); @@ -259,13 +234,11 @@ public class HelixGatewayParticipant implements HelixManagerStateListener { HelixManager participantManager = new ZKHelixManager(_clusterName, _instanceName, InstanceType.PARTICIPANT, _zkAddress); HelixGatewayParticipant participant = - new HelixGatewayParticipant(_helixGatewayServiceChannel, _onDisconnectedCallback, - participantManager, - _initialShardStateMap); - _multiTopStateModelDefinitions.forEach( - stateModelDefinition -> participantManager.getStateMachineEngine() - .registerStateModelFactory(stateModelDefinition, - new HelixGatewayMultiTopStateStateModelFactory(participant))); + new HelixGatewayParticipant(_helixGatewayServiceChannel, _onDisconnectedCallback, participantManager, + _initialShardStateMap, _gatewayServiceManager); + _multiTopStateModelDefinitions.forEach(stateModelDefinition -> participantManager.getStateMachineEngine() + .registerStateModelFactory(stateModelDefinition, + new HelixGatewayMultiTopStateStateModelFactory(participant))); try { participantManager.connect(); } catch (Exception e) { diff --git a/helix-gateway/src/main/java/org/apache/helix/gateway/service/GatewayServiceManager.java b/helix-gateway/src/main/java/org/apache/helix/gateway/service/GatewayServiceManager.java index 94f5783f0..e4d207fd7 100644 --- a/helix-gateway/src/main/java/org/apache/helix/gateway/service/GatewayServiceManager.java +++ b/helix-gateway/src/main/java/org/apache/helix/gateway/service/GatewayServiceManager.java @@ -108,29 +108,48 @@ public class GatewayServiceManager { } } - private GatewayCurrentStateCache getCache(String clusterName) { - return _currentStateCacheMap.computeIfAbsent(clusterName, k -> new GatewayCurrentStateCache(clusterName)); - } - public void resetTargetStateCache(String clusterName, String instanceName) { - getCache(clusterName).resetTargetStateCache(instanceName); + getOrCreateCache(clusterName).resetTargetStateCache(instanceName); } + /** + * Overwrite the current state cache with the new current state map, and return the diff of the change. + * @param clusterName + * @param newCurrentStateMap + * @return + */ public Map<String, Map<String, Map<String, String>>> updateCacheWithNewCurrentStateAndGetDiff(String clusterName, Map<String, Map<String, Map<String, String>>> newCurrentStateMap) { - return getCache(clusterName).updateCacheWithNewCurrentStateAndGetDiff(newCurrentStateMap); + return getOrCreateCache(clusterName).updateCacheWithNewCurrentStateAndGetDiff(newCurrentStateMap); + } + + public void updateCurrentState(String clusterName, String instanceName, String resourceId, String shardId, String toState) { + getOrCreateCache(clusterName).updateCurrentStateOfExistingInstance(instanceName, resourceId, shardId, toState); } - public String serializeTargetState() { + public synchronized String serializeTargetState() { ObjectNode targetStateNode = new ObjectMapper().createObjectNode(); for (String clusterName : _currentStateCacheMap.keySet()) { // add the json node to the target state node - targetStateNode.set(clusterName, getCache(clusterName).serializeTargetAssignmentsToJSONNode()); + targetStateNode.set(clusterName, getOrCreateCache(clusterName).serializeTargetAssignmentsToJSONNode()); } targetStateNode.set("timestamp", objectMapper.valueToTree(System.currentTimeMillis())); return targetStateNode.toString(); } + public void updateTargetState(String clusterName, String instanceName, String resourceId, String shardId, + String toState) { + getOrCreateCache(clusterName).updateTargetStateOfExistingInstance(instanceName, resourceId, shardId, toState); + } + + public String getCurrentState(String clusterName, String instanceName, String resourceId, String shardId) { + return getOrCreateCache(clusterName).getCurrentState(instanceName, resourceId, shardId); + } + + public String getTargetState(String clusterName, String instanceName, String resourceId, String shardId) { + return getOrCreateCache(clusterName).getTargetState(instanceName, resourceId, shardId); + } + /** * Update in memory shard state */ @@ -199,12 +218,10 @@ public class GatewayServiceManager { resetTargetStateCache(clusterName, instanceName); // Create and add the participant to the participant map HelixGatewayParticipant.Builder participantBuilder = - new HelixGatewayParticipant.Builder(_gatewayServiceChannel, instanceName, clusterName, - _zkAddress, - () -> removeHelixGatewayParticipant(clusterName, instanceName)).setInitialShardState( + new HelixGatewayParticipant.Builder(_gatewayServiceChannel, instanceName, clusterName, _zkAddress, + () -> removeHelixGatewayParticipant(clusterName, instanceName), this).setInitialShardState( initialShardStateMap); - SUPPORTED_MULTI_STATE_MODEL_TYPES.forEach( - participantBuilder::addMultiTopStateStateModelDefinition); + SUPPORTED_MULTI_STATE_MODEL_TYPES.forEach(participantBuilder::addMultiTopStateStateModelDefinition); _helixGatewayParticipantMap.computeIfAbsent(clusterName, k -> new ConcurrentHashMap<>()) .put(instanceName, participantBuilder.build()); } @@ -216,6 +233,7 @@ public class GatewayServiceManager { participant.disconnect(); _helixGatewayParticipantMap.get(clusterName).remove(instanceName); } + _currentStateCacheMap.get(clusterName).removeInstanceTargetDataFromCache(instanceName); } private HelixGatewayParticipant getHelixGatewayParticipant(String clusterName, @@ -223,4 +241,8 @@ public class GatewayServiceManager { return _helixGatewayParticipantMap.getOrDefault(clusterName, Collections.emptyMap()) .get(instanceName); } + + private synchronized GatewayCurrentStateCache getOrCreateCache(String clusterName) { + return _currentStateCacheMap.computeIfAbsent(clusterName, k -> new GatewayCurrentStateCache(clusterName)); + } } diff --git a/helix-gateway/src/main/java/org/apache/helix/gateway/util/GatewayCurrentStateCache.java b/helix-gateway/src/main/java/org/apache/helix/gateway/util/GatewayCurrentStateCache.java index a3ba7fbe7..2b8b1c978 100644 --- a/helix-gateway/src/main/java/org/apache/helix/gateway/util/GatewayCurrentStateCache.java +++ b/helix-gateway/src/main/java/org/apache/helix/gateway/util/GatewayCurrentStateCache.java @@ -23,14 +23,17 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ObjectNode; import java.util.HashMap; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * A cache to store the current target assignment, and the reported current state of the instances in a cluster. */ public class GatewayCurrentStateCache { - static ObjectMapper mapper = new ObjectMapper(); - String _clusterName; + private static final Logger logger = LoggerFactory.getLogger(GatewayCurrentStateCache.class); + private static final ObjectMapper mapper = new ObjectMapper(); // A cache of current state. It should be updated by the HelixGatewayServiceChannel // instance -> resource state (resource -> shard -> target state) @@ -38,120 +41,123 @@ public class GatewayCurrentStateCache { // A cache of target state. // instance -> resource state (resource -> shard -> target state) - Map<String, ShardStateMap> _targetStateMap; + final Map<String, ShardStateMap> _targetStateMap; + ObjectNode root = mapper.createObjectNode(); public GatewayCurrentStateCache(String clusterName) { - _clusterName = clusterName; _currentStateMap = new HashMap<>(); _targetStateMap = new HashMap<>(); } public String getCurrentState(String instance, String resource, String shard) { - return _currentStateMap.get(instance).getState(resource, shard); + ShardStateMap shardStateMap = _currentStateMap.get(instance); + return shardStateMap == null ? null : shardStateMap.getState(resource, shard); } public String getTargetState(String instance, String resource, String shard) { - return _targetStateMap.get(instance).getState(resource, shard); + ShardStateMap shardStateMap = _targetStateMap.get(instance); + return shardStateMap == null ? null : shardStateMap.getState(resource, shard); } /** * Update the cached current state of instances in a cluster, and return the diff of the change. - * @param newCurrentStateMap The new current state map of instances in the cluster + * @param userCurrentStateMap The new current state map of instances in the cluster * @return */ public Map<String, Map<String, Map<String, String>>> updateCacheWithNewCurrentStateAndGetDiff( - Map<String, Map<String, Map<String, String>>> newCurrentStateMap) { + Map<String, Map<String, Map<String, String>>> userCurrentStateMap) { + Map<String, ShardStateMap> newCurrentStateMap = new HashMap<>(_currentStateMap); Map<String, Map<String, Map<String, String>>> diff = new HashMap<>(); - for (String instance : newCurrentStateMap.keySet()) { - Map<String, Map<String, String>> newCurrentState = newCurrentStateMap.get(instance); - Map<String, Map<String, String>> resourceStateDiff = - _currentStateMap.computeIfAbsent(instance, k -> new ShardStateMap(new HashMap<>())) - .updateAndGetDiff(newCurrentState); - if (resourceStateDiff != null && !resourceStateDiff.isEmpty()) { - diff.put(instance, resourceStateDiff); + for (String instance : userCurrentStateMap.keySet()) { + ShardStateMap oldStateMap = _currentStateMap.get(instance); + Map<String, Map<String, String>> instanceDiff = oldStateMap == null ? userCurrentStateMap.get(instance) + : oldStateMap.getDiff(userCurrentStateMap.get(instance)); + if (!instanceDiff.isEmpty()) { + diff.put(instance, instanceDiff); } + newCurrentStateMap.put(instance, new ShardStateMap(userCurrentStateMap.get(instance))); } + logger.info("Update current state cache for instances: {}", diff.keySet()); + _currentStateMap = newCurrentStateMap; return diff; } /** - * Update the cache with the current state diff. - * All existing target states remains the same - * @param currentStateDiff + * Update the current state with the changed current state maps. */ - public void updateCacheWithCurrentStateDiff(Map<String, Map<String, Map<String, String>>> currentStateDiff) { - for (String instance : currentStateDiff.keySet()) { - Map<String, Map<String, String>> currentStateDiffMap = currentStateDiff.get(instance); - updateShardStateMapWithDiff(_currentStateMap, instance, currentStateDiffMap); - } + public void updateCurrentStateOfExistingInstance(String instance, String resource, String shard, String shardState) { + logger.info("Update current state of instance: {}, resource: {}, shard: {}, state: {}", instance, resource, shard, + shardState); + updateShardStateMapWithDiff(_currentStateMap, instance, resource, shard, shardState); } /** * Update the target state with the changed target state maps. * All existing target states remains the same - * @param targetStateChangeMap */ - public void updateTargetStateWithDiff(String instance, Map<String, Map<String, String>> targetStateChangeMap) { - updateShardStateMapWithDiff(_targetStateMap, instance, targetStateChangeMap); + public void updateTargetStateOfExistingInstance(String instance, String resource, String shard, String shardState) { + logger.info("Update target state of instance: {}, resource: {}, shard: {}, state: {}", instance, resource, shard, + shardState); + updateShardStateMapWithDiff(_targetStateMap, instance, resource, shard, shardState); + } + + private void updateShardStateMapWithDiff(Map<String, ShardStateMap> stateMap, String instance, String resource, + String shard, String shardState) { + ShardStateMap curStateMap = stateMap.get(instance); + if (curStateMap == null) { + logger.warn("Instance {} is not in the state map, skip updating", instance); + return; + } + curStateMap.updateWithShardState(resource, shard, shardState); } /** * Serialize the target state assignments to a JSON Node. * example : {"instance1":{"resource1":{"shard1":"ONLINE","shard2":"OFFLINE"}}}} */ - public ObjectNode serializeTargetAssignmentsToJSONNode() { - ObjectNode root = mapper.createObjectNode(); + public synchronized ObjectNode serializeTargetAssignmentsToJSONNode() { for (Map.Entry<String, ShardStateMap> entry : _targetStateMap.entrySet()) { root.set(entry.getKey(), entry.getValue().toJSONNode()); } return root; } - private void updateShardStateMapWithDiff(Map<String, ShardStateMap> stateMap, String instance, - Map<String, Map<String, String>> diffMap) { - if (diffMap == null || diffMap.isEmpty()) { - return; - } - stateMap.computeIfAbsent(instance, k -> new ShardStateMap(new HashMap<>())).updateWithDiff(diffMap); + /** + * Remove the target state data of an instance from the cache. + */ + public synchronized void removeInstanceTargetDataFromCache(String instance) { + logger.info("Remove instance target data from cache for instance: {}", instance); + _targetStateMap.remove(instance); + root.remove(instance); } - public void resetTargetStateCache(String instance) { + /** + * Remove the current state data of an instance from the cache to an empty map. + */ + public synchronized void resetTargetStateCache(String instance) { + logger.info("Reset target state cache for instance: {}", instance); _targetStateMap.put(instance, new ShardStateMap(new HashMap<>())); } public static class ShardStateMap { Map<String, Map<String, String>> _stateMap; + ObjectNode root = mapper.createObjectNode(); public ShardStateMap(Map<String, Map<String, String>> stateMap) { - _stateMap = stateMap; + _stateMap = new HashMap<>(stateMap); } - public String getState(String instance, String shard) { - return _stateMap.get(instance).get(shard); + public String getState(String resource, String shard) { + Map<String, String> shardStateMap = _stateMap.get(resource); + return shardStateMap == null ? null : shardStateMap.get(shard); } - private Map<String, Map<String, String>> getShardStateMap() { - return _stateMap; - } - - private void updateWithDiff(Map<String, Map<String, String>> diffMap) { - for (Map.Entry<String, Map<String, String>> diffEntry : diffMap.entrySet()) { - String resource = diffEntry.getKey(); - Map<String, String> diffCurrentState = diffEntry.getValue(); - if (_stateMap.get(resource) != null) { - _stateMap.get(resource).entrySet().forEach(currentMapEntry -> { - String shard = currentMapEntry.getKey(); - if (diffCurrentState.get(shard) != null) { - currentMapEntry.setValue(diffCurrentState.get(shard)); - } - }); - } else { - _stateMap.put(resource, diffCurrentState); - } - } + public synchronized void updateWithShardState(String resource, String shard, String shardState) { + logger.info("Update ShardStateMap of resource: {}, shard: {}, state: {}", resource, shard, shardState); + _stateMap.computeIfAbsent(resource, k -> new HashMap<>()).put(shard, shardState); } - private Map<String, Map<String, String>> updateAndGetDiff(Map<String, Map<String, String>> newCurrentStateMap) { + private Map<String, Map<String, String>> getDiff(Map<String, Map<String, String>> newCurrentStateMap) { Map<String, Map<String, String>> diff = new HashMap<>(); for (Map.Entry<String, Map<String, String>> entry : newCurrentStateMap.entrySet()) { String resource = entry.getKey(); @@ -169,7 +175,6 @@ public class GatewayCurrentStateCache { } } } - _stateMap = newCurrentStateMap; return diff; } @@ -177,8 +182,7 @@ public class GatewayCurrentStateCache { * Serialize the shard state map to a JSON object. * @return a JSON object representing the shard state map. Example: {"shard1":"ONLINE","shard2":"OFFLINE"} */ - public ObjectNode toJSONNode() { - ObjectNode root = mapper.createObjectNode(); + public synchronized ObjectNode toJSONNode() { for (Map.Entry<String, Map<String, String>> entry : _stateMap.entrySet()) { String resource = entry.getKey(); ObjectNode resourceNode = mapper.createObjectNode(); diff --git a/helix-gateway/src/main/java/org/apache/helix/gateway/util/StateTransitionMessageTranslateUtil.java b/helix-gateway/src/main/java/org/apache/helix/gateway/util/StateTransitionMessageTranslateUtil.java index f9f2b4c3d..a2d07085b 100644 --- a/helix-gateway/src/main/java/org/apache/helix/gateway/util/StateTransitionMessageTranslateUtil.java +++ b/helix-gateway/src/main/java/org/apache/helix/gateway/util/StateTransitionMessageTranslateUtil.java @@ -27,6 +27,7 @@ import org.apache.helix.HelixDefinedState; import org.apache.helix.gateway.api.constant.GatewayServiceEventType; import org.apache.helix.gateway.participant.HelixGatewayParticipant; import org.apache.helix.gateway.service.GatewayServiceEvent; +import org.apache.helix.gateway.service.GatewayServiceManager; import org.apache.helix.model.Message; import proto.org.apache.helix.gateway.HelixGatewayServiceOuterClass; import proto.org.apache.helix.gateway.HelixGatewayServiceOuterClass.ShardChangeRequests; @@ -79,7 +80,8 @@ public final class StateTransitionMessageTranslateUtil { * contains the state of each shard upon connection or result of state transition request. * @return GatewayServiceEvent */ - public static GatewayServiceEvent translateShardStateMessageToEvent(ShardStateMessage request) { + public static GatewayServiceEvent translateShardStateMessageToEventAndUpdateCache( + GatewayServiceManager manager, ShardStateMessage request) { GatewayServiceEvent.GateWayServiceEventBuilder builder; if (request.hasShardState()) { // init connection to gateway service ShardState shardState = request.getShardState(); @@ -90,6 +92,11 @@ public final class StateTransitionMessageTranslateUtil { .put(state.getShardName(), state.getCurrentState()); } } + // update current state cache. We always overwrite the current state map for initial connection + Map<String, Map<String, Map<String, String>>> newShardStateMap = new HashMap<>(); + newShardStateMap.put(shardState.getInstanceName(), shardStateMap); + manager.updateCacheWithNewCurrentStateAndGetDiff(shardState.getClusterName(), newShardStateMap); + builder = new GatewayServiceEvent.GateWayServiceEventBuilder(GatewayServiceEventType.CONNECT).setClusterName( shardState.getClusterName()).setParticipantName(shardState.getInstanceName()) .setShardStateMap(shardStateMap); @@ -103,6 +110,9 @@ public final class StateTransitionMessageTranslateUtil { GatewayServiceEvent.StateTransitionResult result = new GatewayServiceEvent.StateTransitionResult(shardTransition.getResourceName(), shardTransition.getShardName(), shardTransition.getCurrentState()); + // update current state cache + manager.updateCurrentState(shardTransitionStatus.getClusterName(), shardTransitionStatus.getInstanceName(), + shardTransition.getResourceName(), shardTransition.getShardName(), shardTransition.getCurrentState()); stResult.add(result); } builder = new GatewayServiceEvent.GateWayServiceEventBuilder(GatewayServiceEventType.UPDATE).setClusterName( diff --git a/helix-gateway/src/test/java/org/apache/helix/gateway/participant/TestHelixGatewayParticipant.java b/helix-gateway/src/test/java/org/apache/helix/gateway/participant/TestHelixGatewayParticipant.java index 75e9fc581..128a22857 100644 --- a/helix-gateway/src/test/java/org/apache/helix/gateway/participant/TestHelixGatewayParticipant.java +++ b/helix-gateway/src/test/java/org/apache/helix/gateway/participant/TestHelixGatewayParticipant.java @@ -21,6 +21,7 @@ package org.apache.helix.gateway.participant; import java.io.IOException; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -32,6 +33,8 @@ import org.apache.helix.ConfigAccessor; import org.apache.helix.TestHelper; import org.apache.helix.common.ZkTestBase; import org.apache.helix.gateway.api.service.HelixGatewayServiceChannel; +import org.apache.helix.gateway.channel.GatewayServiceChannelConfig; +import org.apache.helix.gateway.service.GatewayServiceManager; import org.apache.helix.integration.manager.ClusterControllerManager; import org.apache.helix.manager.zk.ZKHelixManager; import org.apache.helix.model.ClusterConfig; @@ -62,8 +65,17 @@ public class TestHelixGatewayParticipant extends ZkTestBase { private final Map<String, ShardChangeRequests> _pendingMessageMap = new ConcurrentHashMap<>(); private final AtomicInteger _onDisconnectCallbackCount = new AtomicInteger(); + private GatewayServiceManager _gatewayServiceManager; + @BeforeClass public void beforeClass() { + GatewayServiceChannelConfig.GatewayServiceProcessorConfigBuilder builder = + new GatewayServiceChannelConfig.GatewayServiceProcessorConfigBuilder(); + + builder.setParticipantConnectionChannelType(GatewayServiceChannelConfig.ChannelType.GRPC_SERVER).setGrpcServerPort(5001); + GatewayServiceChannelConfig config = builder.build(); + _gatewayServiceManager = new GatewayServiceManager(ZK_ADDR, config); + // Set up the Helix cluster ConfigAccessor configAccessor = new ConfigAccessor(_gZkClient); _gSetupTool.addCluster(CLUSTER_NAME, true); @@ -100,9 +112,10 @@ public class TestHelixGatewayParticipant extends ZkTestBase { */ private HelixGatewayParticipant addParticipant(String participantName, Map<String, Map<String, String>> initialShardMap) { + _gatewayServiceManager.resetTargetStateCache(CLUSTER_NAME, participantName); HelixGatewayParticipant participant = new HelixGatewayParticipant.Builder(new MockHelixGatewayServiceChannel(_pendingMessageMap), participantName, - CLUSTER_NAME, ZK_ADDR, _onDisconnectCallbackCount::incrementAndGet).addMultiTopStateStateModelDefinition( + CLUSTER_NAME, ZK_ADDR, _onDisconnectCallbackCount::incrementAndGet, _gatewayServiceManager).addMultiTopStateStateModelDefinition( TEST_STATE_MODEL).setInitialShardState(initialShardMap).build(); _participants.add(participant); return participant; @@ -175,8 +188,26 @@ public class TestHelixGatewayParticipant extends ZkTestBase { private void processPendingMessage(HelixGatewayParticipant participant, boolean isSuccess, String toState) { ShardChangeRequests requests = _pendingMessageMap.remove(participant.getInstanceName()); - participant.completeStateTransition(requests.getRequest(0).getResourceName(),requests.getRequest(0).getShardName(), - isSuccess ? toState : "WRONG_STATE"); + Map<String, Map<String, Map<String, String>>> newSInstanceStateMap = new HashMap<>(); + newSInstanceStateMap.put(participant.getInstanceName(), + createSingleShardStateMap(requests.getRequest(0).getResourceName(), requests.getRequest(0).getShardName(), + isSuccess ? toState : "ERROR")); + _gatewayServiceManager.updateCacheWithNewCurrentStateAndGetDiff(CLUSTER_NAME, newSInstanceStateMap); + + participant.completeStateTransition(requests.getRequest(0).getResourceName(), requests.getRequest(0).getShardName(), + isSuccess ? toState : "ERROR"); + } + + /** + * Create a single shard state map. + */ + Map<String, Map<String, String>> createSingleShardStateMap( String resource, String shard, String state) { + + Map<String, Map<String, String>> resourceStateMap = new HashMap<>(); + Map<String, String> shardStateMap = new HashMap<>(); + shardStateMap.put(shard, state); + resourceStateMap.put(resource, shardStateMap); + return resourceStateMap; } /** @@ -209,7 +240,9 @@ public class TestHelixGatewayParticipant extends ZkTestBase { .getResourceIdealState(CLUSTER_NAME, resourceName) .getPartitionSet()) { String helixCurrentState = getHelixCurrentState(instanceName, resourceName, shardId); - if (!participant.getCurrentState(resourceName, shardId).equals(helixCurrentState)) { + if (!participant.getCurrentState(resourceName, shardId).equals(helixCurrentState) && !( + participant.getCurrentState(resourceName, shardId).equals("DROPPED") && helixCurrentState.equals( + "UNASSIGNED"))) { return false; } } @@ -218,6 +251,28 @@ public class TestHelixGatewayParticipant extends ZkTestBase { }), TestHelper.WAIT_DURATION)); } + private void verifyGatewayTargetStateMatchHelixTargetState() throws Exception { + Assert.assertTrue(TestHelper.verify(() -> _participants.stream().allMatch(participant -> { + String instanceName = participant.getInstanceName(); + for (String resourceName : _gSetupTool.getClusterManagementTool().getResourcesInCluster(CLUSTER_NAME)) { + for (String shardId : _gSetupTool.getClusterManagementTool() + .getResourceIdealState(CLUSTER_NAME, resourceName) + .getPartitionSet()) { + String helixTargetState = getHelixCurrentState(instanceName, resourceName, shardId); + if (_gatewayServiceManager.getTargetState(CLUSTER_NAME, instanceName, resourceName, shardId) == null) { + System.out.println("Gateway target state is null for instance: " + instanceName + ", resource: " + resourceName + ", shard: " + shardId); + } + if (!participant.getCurrentState(resourceName, shardId).equals(helixTargetState) && !( + participant.getCurrentState(resourceName, shardId).equals("DROPPED") && helixTargetState.equals( + "UNASSIGNED"))) { + return false; + } + } + } + return true; + }), 6000L)); + } + /** * Verify that all shards for a given instance are in a specific state. */ @@ -256,6 +311,7 @@ public class TestHelixGatewayParticipant extends ZkTestBase { // Verify that the cluster converges and all states are "ONLINE" Assert.assertTrue(_clusterVerifier.verify()); verifyGatewayStateMatchesHelixState(); + verifyGatewayTargetStateMatchHelixTargetState(); } @Test(dependsOnMethods = "testProcessStateTransitionMessageSuccess") @@ -293,7 +349,9 @@ public class TestHelixGatewayParticipant extends ZkTestBase { verifyHelixPartitionStates(participant.getInstanceName(), HelixGatewayParticipant.UNASSIGNED_STATE); // Re-add the participant with its initial state - addParticipant(participant.getInstanceName(), participant.getShardStateMap()); + addParticipant(participant.getInstanceName(), createSingleShardStateMap(TEST_DB, "TestDB_0", + _gatewayServiceManager.getCurrentState(CLUSTER_NAME, participant.getInstanceName(), TEST_DB, + "TestDB_0"))); Assert.assertTrue(_clusterVerifier.verify()); // Verify the Helix state is "ONLINE" @@ -309,8 +367,10 @@ public class TestHelixGatewayParticipant extends ZkTestBase { // Remove shard preference and re-add the participant removeFromPreferenceList(participant); - HelixGatewayParticipant participantReplacement = - addParticipant(participant.getInstanceName(), participant.getShardStateMap()); + HelixGatewayParticipant participantReplacement = addParticipant(participant.getInstanceName(), + createSingleShardStateMap(TEST_DB, "TestDB_0", + _gatewayServiceManager.getCurrentState(CLUSTER_NAME, participant.getInstanceName(), TEST_DB, + "TestDB_0"))); verifyPendingMessages(List.of(participantReplacement)); // Process the pending message successfully @@ -319,6 +379,7 @@ public class TestHelixGatewayParticipant extends ZkTestBase { // Verify that the cluster converges and states are correctly updated to "ONLINE" Assert.assertTrue(_clusterVerifier.verify()); verifyGatewayStateMatchesHelixState(); + verifyGatewayTargetStateMatchHelixTargetState(); } @Test(dependsOnMethods = "testProcessStateTransitionAfterReconnectAfterDroppingPartition") diff --git a/helix-gateway/src/test/java/org/apache/helix/gateway/util/TestGatewayCurrentStateCache.java b/helix-gateway/src/test/java/org/apache/helix/gateway/util/TestGatewayCurrentStateCache.java index 99fec4a25..17aa25637 100644 --- a/helix-gateway/src/test/java/org/apache/helix/gateway/util/TestGatewayCurrentStateCache.java +++ b/helix-gateway/src/test/java/org/apache/helix/gateway/util/TestGatewayCurrentStateCache.java @@ -51,31 +51,28 @@ public class TestGatewayCurrentStateCache { } @Test - public void testUpdateCacheWithCurrentStateDiff() { - Map<String, Map<String, Map<String, String>>> diff = new HashMap<>(); + public void testUpdateCacheWithExistingStateAndGetDiff() { + // Initial state + Map<String, Map<String, Map<String, String>>> initialState = new HashMap<>(); Map<String, Map<String, String>> instanceState = new HashMap<>(); Map<String, String> shardState = new HashMap<>(); - shardState.put("shard2", "ONLINE"); shardState.put("shard1", "ONLINE"); instanceState.put("resource1", shardState); - diff.put("instance1", instanceState); + initialState.put("instance1", instanceState); + cache.updateCacheWithNewCurrentStateAndGetDiff(initialState); - cache.updateCacheWithCurrentStateDiff(diff); - - Assert.assertEquals(cache.getCurrentState("instance1", "resource1", "shard1"), "ONLINE"); - Assert.assertEquals(cache.getCurrentState("instance1", "resource1", "shard2"), "ONLINE"); - } - - @Test - public void testUpdateTargetStateWithDiff() { - Map<String, Map<String, String>> targetStateChange = new HashMap<>(); - Map<String, String> shardState = new HashMap<>(); - shardState.put("shard1", "OFFLINE"); - targetStateChange.put("resource1", shardState); + // New state with a change + Map<String, Map<String, Map<String, String>>> newState = new HashMap<>(); + Map<String, Map<String, String>> newInstanceState = new HashMap<>(); + Map<String, String> newShardState = new HashMap<>(); + newShardState.put("shard1", "OFFLINE"); + newInstanceState.put("resource1", newShardState); + newState.put("instance1", newInstanceState); - cache.updateTargetStateWithDiff("instance1", targetStateChange); + Map<String, Map<String, Map<String, String>>> diff = cache.updateCacheWithNewCurrentStateAndGetDiff(newState); - Assert.assertEquals(cache.getTargetState("instance1", "resource1", "shard1"), "OFFLINE"); - Assert.assertEquals(cache.serializeTargetAssignmentsToJSONNode().toString(), "{\"instance1\":{\"resource1\":{\"shard1\":\"OFFLINE\"}}}"); + Assert.assertNotNull(diff); + Assert.assertEquals(diff.size(), 1); + Assert.assertEquals(diff.get("instance1").get("resource1").get("shard1"), "OFFLINE"); } }