This is an automated email from the ASF dual-hosted git repository.

totalo pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new 026bbf7  Optimize traffic rule getComputeNodeInstanceIds logic (#15977)
026bbf7 is described below

commit 026bbf7f8e1952808e3b332295f51a91591e106b
Author: Zhengqiang Duan <[email protected]>
AuthorDate: Fri Mar 11 08:53:20 2022 +0800

    Optimize traffic rule getComputeNodeInstanceIds logic (#15977)
---
 .../infra/instance/InstanceContext.java            | 15 ++++++++------
 .../jdbc/core/connection/ConnectionManager.java    | 18 ++++++++--------
 .../core/connection/ConnectionManagerTest.java     | 11 ++--------
 .../traffic/spi/TrafficLoadBalanceAlgorithm.java   |  3 ++-
 .../RandomTrafficLoadBalanceAlgorithm.java         |  3 ++-
 .../RoundRobinTrafficLoadBalanceAlgorithm.java     |  3 ++-
 .../traffic/engine/TrafficEngine.java              | 22 ++++++--------------
 .../shardingsphere/traffic/rule/TrafficRule.java   |  3 ++-
 .../algorithm/engine/TrafficEngineTest.java        | 24 +++++++++-------------
 .../RandomTrafficLoadBalanceAlgorithmTest.java     |  3 ++-
 .../RoundRobinTrafficLoadBalanceAlgorithmTest.java |  7 ++++---
 .../traffic/rule/TrafficRuleTest.java              |  4 ++--
 .../fixture/TestTrafficLoadBalanceAlgorithm.java   |  3 ++-
 13 files changed, 53 insertions(+), 66 deletions(-)

diff --git a/shardingsphere-infra/shardingsphere-infra-common/src/main/java/org/apache/shardingsphere/infra/instance/InstanceContext.java b/shardingsphere-infra/shardingsphere-infra-common/src/main/java/org/apache/shardingsphere/infra/instance/InstanceContext.java
index be3a20c..1e6b79d 100644
--- a/shardingsphere-infra/shardingsphere-infra-common/src/main/java/org/apache/shardingsphere/infra/instance/InstanceContext.java
+++ b/shardingsphere-infra/shardingsphere-infra-common/src/main/java/org/apache/shardingsphere/infra/instance/InstanceContext.java
@@ -18,7 +18,9 @@
 package org.apache.shardingsphere.infra.instance;
 
 import lombok.Getter;
+import org.apache.commons.collections4.CollectionUtils;
 import org.apache.shardingsphere.infra.config.mode.ModeConfiguration;
+import org.apache.shardingsphere.infra.instance.definition.InstanceId;
 import org.apache.shardingsphere.infra.instance.definition.InstanceType;
 import org.apache.shardingsphere.infra.instance.workerid.WorkerIdGenerator;
 import org.apache.shardingsphere.infra.state.StateContext;
@@ -27,6 +29,7 @@ import org.apache.shardingsphere.infra.state.StateType;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.LinkedList;
+import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
 
@@ -141,13 +144,13 @@ public final class InstanceContext {
      * @param labels collection of contained label
      * @return compute node instances
      */
-    public Collection<ComputeNodeInstance> getComputeNodeInstances(final InstanceType instanceType, final Collection<String> labels) {
-        Collection<ComputeNodeInstance> result = new ArrayList<>(computeNodeInstances.size());
-        computeNodeInstances.forEach(each -> {
-            if (each.getInstanceDefinition().getInstanceType() == instanceType && each.getLabels().stream().anyMatch(labels::contains)) {
-                result.add(each);
+    public List<InstanceId> getComputeNodeInstanceIds(final InstanceType instanceType, final Collection<String> labels) {
+        List<InstanceId> result = new ArrayList<>(computeNodeInstances.size());
+        for (ComputeNodeInstance each : computeNodeInstances) {
+            if (each.getInstanceDefinition().getInstanceType() == instanceType && CollectionUtils.containsAny(labels, each.getLabels())) {
+                result.add(each.getInstanceDefinition().getInstanceId());
             }
-        });
+        }
         return result;
     }
 }
diff --git a/shardingsphere-jdbc/shardingsphere-jdbc-core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/connection/ConnectionManager.java b/shardingsphere-jdbc/shardingsphere-jdbc-core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/connection/ConnectionManager.java
index 1efb82a..b1aca6d 100644
--- a/shardingsphere-jdbc/shardingsphere-jdbc-core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/connection/ConnectionManager.java
+++ b/shardingsphere-jdbc/shardingsphere-jdbc-core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/connection/ConnectionManager.java
@@ -29,7 +29,6 @@ import 
org.apache.shardingsphere.infra.datasource.pool.creator.DataSourcePoolCre
 import org.apache.shardingsphere.infra.datasource.props.DataSourceProperties;
 import 
org.apache.shardingsphere.infra.executor.sql.execute.engine.ConnectionMode;
 import 
org.apache.shardingsphere.infra.executor.sql.prepare.driver.jdbc.ExecutorJDBCConnectionManager;
-import org.apache.shardingsphere.infra.instance.ComputeNodeInstance;
 import org.apache.shardingsphere.infra.instance.definition.InstanceId;
 import org.apache.shardingsphere.infra.instance.definition.InstanceType;
 import org.apache.shardingsphere.infra.metadata.user.ShardingSphereUser;
@@ -91,32 +90,31 @@ public final class ConnectionManager implements 
ExecutorJDBCConnectionManager, A
         Preconditions.checkState(!dataSourcePropsMap.isEmpty(), "Can not get 
data source properties from meta data.");
         DataSourceProperties dataSourcePropsSample = 
dataSourcePropsMap.values().iterator().next();
         Collection<ShardingSphereUser> users = 
metaDataPersistService.get().getGlobalRuleService().loadUsers();
-        Collection<ComputeNodeInstance> instances = 
contextManager.getInstanceContext().getComputeNodeInstances(InstanceType.PROXY, 
trafficRule.get().getLabels());
-        return 
DataSourcePoolCreator.create(createDataSourcePropertiesMap(instances, users, 
dataSourcePropsSample, schema));
+        Collection<InstanceId> instanceIds = 
contextManager.getInstanceContext().getComputeNodeInstanceIds(InstanceType.PROXY,
 trafficRule.get().getLabels());
+        return 
DataSourcePoolCreator.create(createDataSourcePropertiesMap(instanceIds, users, 
dataSourcePropsSample, schema));
     }
     
-    private Map<String, DataSourceProperties> 
createDataSourcePropertiesMap(final Collection<ComputeNodeInstance> instances, 
final Collection<ShardingSphereUser> users,
+    private Map<String, DataSourceProperties> 
createDataSourcePropertiesMap(final Collection<InstanceId> instanceIds, final 
Collection<ShardingSphereUser> users,
                                                                             
final DataSourceProperties dataSourcePropsSample, final String schema) {
         Map<String, DataSourceProperties> result = new LinkedHashMap<>();
-        for (ComputeNodeInstance each : instances) {
-            result.put(each.getInstanceDefinition().getInstanceId().getId(), 
createDataSourceProperties(each, users, dataSourcePropsSample, schema));
+        for (InstanceId each : instanceIds) {
+            result.put(each.getId(), createDataSourceProperties(each, users, 
dataSourcePropsSample, schema));
         }
         return result;
     }
     
-    private DataSourceProperties createDataSourceProperties(final 
ComputeNodeInstance instance, final Collection<ShardingSphereUser> users,
+    private DataSourceProperties createDataSourceProperties(final InstanceId 
instanceId, final Collection<ShardingSphereUser> users,
                                                             final 
DataSourceProperties dataSourcePropsSample, final String schema) {
         Map<String, Object> props = 
dataSourcePropsSample.getAllLocalProperties();
-        props.put("jdbcUrl", createJdbcUrl(instance, schema, props));
+        props.put("jdbcUrl", createJdbcUrl(instanceId, schema, props));
         ShardingSphereUser user = users.iterator().next();
         props.put("username", user.getGrantee().getUsername());
         props.put("password", user.getPassword());
         return new DataSourceProperties(HikariDataSource.class.getName(), 
props);
     }
     
-    private String createJdbcUrl(final ComputeNodeInstance instance, final 
String schema, final Map<String, Object> props) {
+    private String createJdbcUrl(final InstanceId instanceId, final String 
schema, final Map<String, Object> props) {
         String jdbcUrl = String.valueOf(props.get("jdbcUrl"));
-        InstanceId instanceId = 
instance.getInstanceDefinition().getInstanceId();
         String jdbcUrlPrefix = jdbcUrl.substring(0, jdbcUrl.indexOf("//"));
         String jdbcUrlSuffix = jdbcUrl.contains("?") ? 
jdbcUrl.substring(jdbcUrl.indexOf("?")) : "";
         return String.format("%s//%s:%s/%s%s", jdbcUrlPrefix, 
instanceId.getIp(), instanceId.getUniqueSign(), schema, jdbcUrlSuffix);
diff --git 
a/shardingsphere-jdbc/shardingsphere-jdbc-core/src/test/java/org/apache/shardingsphere/driver/jdbc/core/connection/ConnectionManagerTest.java
 
b/shardingsphere-jdbc/shardingsphere-jdbc-core/src/test/java/org/apache/shardingsphere/driver/jdbc/core/connection/ConnectionManagerTest.java
index ff30fad1..b6fd8cf 100644
--- 
a/shardingsphere-jdbc/shardingsphere-jdbc-core/src/test/java/org/apache/shardingsphere/driver/jdbc/core/connection/ConnectionManagerTest.java
+++ 
b/shardingsphere-jdbc/shardingsphere-jdbc-core/src/test/java/org/apache/shardingsphere/driver/jdbc/core/connection/ConnectionManagerTest.java
@@ -23,8 +23,7 @@ import org.apache.shardingsphere.infra.database.DefaultSchema;
 import 
org.apache.shardingsphere.infra.datasource.pool.creator.DataSourcePoolCreator;
 import org.apache.shardingsphere.infra.datasource.props.DataSourceProperties;
 import 
org.apache.shardingsphere.infra.executor.sql.execute.engine.ConnectionMode;
-import org.apache.shardingsphere.infra.instance.ComputeNodeInstance;
-import org.apache.shardingsphere.infra.instance.definition.InstanceDefinition;
+import org.apache.shardingsphere.infra.instance.definition.InstanceId;
 import org.apache.shardingsphere.infra.instance.definition.InstanceType;
 import org.apache.shardingsphere.infra.metadata.user.ShardingSphereUser;
 import org.apache.shardingsphere.mode.manager.ContextManager;
@@ -94,7 +93,7 @@ public final class ConnectionManagerTest {
         
when(result.getMetaDataContexts().getMetaDataPersistService()).thenReturn(Optional.of(metaDataPersistService));
         
when(result.getMetaDataContexts().getGlobalRuleMetaData().findSingleRule(TransactionRule.class)).thenReturn(Optional.empty());
         
when(result.getMetaDataContexts().getGlobalRuleMetaData().findSingleRule(TrafficRule.class)).thenReturn(Optional.of(trafficRule));
-        
when(result.getInstanceContext().getComputeNodeInstances(InstanceType.PROXY, 
Arrays.asList("OLTP", 
"OLAP"))).thenReturn(Collections.singletonList(mockComputeNodeInstance()));
+        
when(result.getInstanceContext().getComputeNodeInstanceIds(InstanceType.PROXY, 
Arrays.asList("OLTP", "OLAP"))).thenReturn(Collections.singletonList(new 
InstanceId("127.0.0.1@3307")));
         dataSourcePoolCreator = mockStatic(DataSourcePoolCreator.class);
         Map<String, DataSource> trafficDataSourceMap = 
mockTrafficDataSourceMap();
         when(DataSourcePoolCreator.create((Map) 
any())).thenReturn(trafficDataSourceMap);
@@ -128,12 +127,6 @@ public final class ConnectionManagerTest {
         return result;
     }
     
-    private ComputeNodeInstance mockComputeNodeInstance() {
-        ComputeNodeInstance result = new ComputeNodeInstance(new 
InstanceDefinition(InstanceType.PROXY, "127.0.0.1@3307"));
-        result.setLabels(Collections.singletonList("OLTP"));
-        return result;
-    }
-    
     private TrafficRule mockTrafficRule() {
         TrafficRule result = mock(TrafficRule.class);
         when(result.getLabels()).thenReturn(Arrays.asList("OLTP", "OLAP"));
diff --git a/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-api/src/main/java/org/apache/shardingsphere/traffic/spi/TrafficLoadBalanceAlgorithm.java b/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-api/src/main/java/org/apache/shardingsphere/traffic/spi/TrafficLoadBalanceAlgorithm.java
index f237b94..6d09a64 100644
--- a/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-api/src/main/java/org/apache/shardingsphere/traffic/spi/TrafficLoadBalanceAlgorithm.java
+++ b/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-api/src/main/java/org/apache/shardingsphere/traffic/spi/TrafficLoadBalanceAlgorithm.java
@@ -18,6 +18,7 @@
 package org.apache.shardingsphere.traffic.spi;
 
 import 
org.apache.shardingsphere.infra.config.algorithm.ShardingSphereAlgorithm;
+import org.apache.shardingsphere.infra.instance.definition.InstanceId;
 import org.apache.shardingsphere.spi.required.RequiredSPI;
 
 import java.util.List;
@@ -34,5 +35,5 @@ public interface TrafficLoadBalanceAlgorithm extends 
ShardingSphereAlgorithm, Re
      * @param instanceIds instance id collection
      * @return instance id
      */
-    String getInstanceId(String name, List<String> instanceIds);
+    InstanceId getInstanceId(String name, List<InstanceId> instanceIds);
 }
diff --git 
a/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-core/src/main/java/org/apache/shardingsphere/traffic/algorithm/loadbalance/RandomTrafficLoadBalanceAlgorithm.java
 
b/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-core/src/main/java/org/apache/shardingsphere/traffic/algorithm/loadbalance/RandomTrafficLoadBalanceAlgorithm.java
index 9fd4ba6..9b164a3 100644
--- 
a/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-core/src/main/java/org/apache/shardingsphere/traffic/algorithm/loadbalance/RandomTrafficLoadBalanceAlgorithm.java
+++ 
b/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-core/src/main/java/org/apache/shardingsphere/traffic/algorithm/loadbalance/RandomTrafficLoadBalanceAlgorithm.java
@@ -17,6 +17,7 @@
 
 package org.apache.shardingsphere.traffic.algorithm.loadbalance;
 
+import org.apache.shardingsphere.infra.instance.definition.InstanceId;
 import org.apache.shardingsphere.traffic.spi.TrafficLoadBalanceAlgorithm;
 
 import java.util.List;
@@ -28,7 +29,7 @@ import java.util.concurrent.ThreadLocalRandom;
 public final class RandomTrafficLoadBalanceAlgorithm implements 
TrafficLoadBalanceAlgorithm {
     
     @Override
-    public String getInstanceId(final String name, final List<String> 
instanceIds) {
+    public InstanceId getInstanceId(final String name, final List<InstanceId> 
instanceIds) {
         return 
instanceIds.get(ThreadLocalRandom.current().nextInt(instanceIds.size()));
     }
     
diff --git 
a/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-core/src/main/java/org/apache/shardingsphere/traffic/algorithm/loadbalance/RoundRobinTrafficLoadBalanceAlgorithm.java
 
b/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-core/src/main/java/org/apache/shardingsphere/traffic/algorithm/loadbalance/RoundRobinTrafficLoadBalanceAlgorithm.java
index 509ea4a..fc925ab 100644
--- 
a/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-core/src/main/java/org/apache/shardingsphere/traffic/algorithm/loadbalance/RoundRobinTrafficLoadBalanceAlgorithm.java
+++ 
b/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-core/src/main/java/org/apache/shardingsphere/traffic/algorithm/loadbalance/RoundRobinTrafficLoadBalanceAlgorithm.java
@@ -17,6 +17,7 @@
 
 package org.apache.shardingsphere.traffic.algorithm.loadbalance;
 
+import org.apache.shardingsphere.infra.instance.definition.InstanceId;
 import org.apache.shardingsphere.traffic.spi.TrafficLoadBalanceAlgorithm;
 
 import java.util.List;
@@ -31,7 +32,7 @@ public final class RoundRobinTrafficLoadBalanceAlgorithm 
implements TrafficLoadB
     private static final ConcurrentHashMap<String, AtomicInteger> COUNTS = new 
ConcurrentHashMap<>();
     
     @Override
-    public String getInstanceId(final String name, final List<String> 
instanceIds) {
+    public InstanceId getInstanceId(final String name, final List<InstanceId> 
instanceIds) {
         AtomicInteger count = COUNTS.containsKey(name) ? COUNTS.get(name) : 
new AtomicInteger(0);
         COUNTS.putIfAbsent(name, count);
         count.compareAndSet(instanceIds.size(), 0);
diff --git a/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-core/src/main/java/org/apache/shardingsphere/traffic/engine/TrafficEngine.java b/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-core/src/main/java/org/apache/shardingsphere/traffic/engine/TrafficEngine.java
index 1878dc5..db4daa7 100644
--- a/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-core/src/main/java/org/apache/shardingsphere/traffic/engine/TrafficEngine.java
+++ b/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-core/src/main/java/org/apache/shardingsphere/traffic/engine/TrafficEngine.java
@@ -19,15 +19,14 @@ package org.apache.shardingsphere.traffic.engine;
 
 import lombok.RequiredArgsConstructor;
 import org.apache.shardingsphere.infra.binder.LogicSQL;
-import org.apache.shardingsphere.infra.instance.ComputeNodeInstance;
 import org.apache.shardingsphere.infra.instance.InstanceContext;
+import org.apache.shardingsphere.infra.instance.definition.InstanceId;
 import org.apache.shardingsphere.infra.instance.definition.InstanceType;
 import org.apache.shardingsphere.traffic.context.TrafficContext;
 import org.apache.shardingsphere.traffic.rule.TrafficRule;
 import org.apache.shardingsphere.traffic.rule.TrafficStrategyRule;
+import org.apache.shardingsphere.traffic.spi.TrafficLoadBalanceAlgorithm;
 
-import java.util.ArrayList;
-import java.util.Collection;
 import java.util.List;
 import java.util.Optional;
 
@@ -54,11 +53,11 @@ public final class TrafficEngine {
         if (!strategyRule.isPresent() || 
isInvalidStrategyRule(strategyRule.get())) {
             return result;
         }
-        List<String> instanceIds = getInstanceIdsByLabels(strategyRule.get().getLabels());
+        List<InstanceId> instanceIds = instanceContext.getComputeNodeInstanceIds(InstanceType.PROXY, strategyRule.get().getLabels());
         if (!instanceIds.isEmpty()) {
-            String instanceId = 1 == instanceIds.size() 
-                    ? instanceIds.iterator().next() : strategyRule.get().getLoadBalancer().getInstanceId(strategyRule.get().getName(), instanceIds);
-            result.setInstanceId(instanceId);
+            TrafficLoadBalanceAlgorithm loadBalancer = strategyRule.get().getLoadBalancer();
+            InstanceId instanceId = 1 == instanceIds.size() ? instanceIds.iterator().next() : loadBalancer.getInstanceId(strategyRule.get().getName(), instanceIds);
+            result.setInstanceId(instanceId.getId());
         }
         return result;
     }
@@ -66,13 +65,4 @@ public final class TrafficEngine {
     private boolean isInvalidStrategyRule(final TrafficStrategyRule 
strategyRule) {
         return strategyRule.getLabels().isEmpty() || null == 
strategyRule.getLoadBalancer();
     }
-    
-    private List<String> getInstanceIdsByLabels(final Collection<String> 
labels) {
-        List<String> result = new ArrayList<>();
-        Collection<ComputeNodeInstance> instances = 
instanceContext.getComputeNodeInstances(InstanceType.PROXY, labels);
-        for (ComputeNodeInstance each : instances) {
-            result.add(each.getInstanceDefinition().getInstanceId().getId());
-        }
-        return result;
-    }
 }
diff --git 
a/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-core/src/main/java/org/apache/shardingsphere/traffic/rule/TrafficRule.java
 
b/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-core/src/main/java/org/apache/shardingsphere/traffic/rule/TrafficRule.java
index 557ca11..2a44bbf 100644
--- 
a/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-core/src/main/java/org/apache/shardingsphere/traffic/rule/TrafficRule.java
+++ 
b/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-core/src/main/java/org/apache/shardingsphere/traffic/rule/TrafficRule.java
@@ -43,6 +43,7 @@ import java.util.Collection;
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
 import java.util.LinkedList;
 import java.util.Map;
 import java.util.Map.Entry;
@@ -108,7 +109,7 @@ public final class TrafficRule implements GlobalRule {
             result = new TrafficStrategyRule(strategyConfig.getName(), 
Collections.emptyList(), trafficAlgorithm, null);
         } else {
             TrafficLoadBalanceAlgorithm loadBalancer = 
getLoadBalancer(loadBalancers, strategyConfig.getLoadBalancerName());
-            result = new TrafficStrategyRule(strategyConfig.getName(), 
strategyConfig.getLabels(), trafficAlgorithm, loadBalancer);
+            result = new TrafficStrategyRule(strategyConfig.getName(), new 
LinkedHashSet<>(strategyConfig.getLabels()), trafficAlgorithm, loadBalancer);
         }
         return result;
     }
diff --git 
a/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-core/src/test/java/org/apache/shardingsphere/traffic/algorithm/engine/TrafficEngineTest.java
 
b/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-core/src/test/java/org/apache/shardingsphere/traffic/algorithm/engine/TrafficEngineTest.java
index 8d9c16c..7602de9 100644
--- 
a/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-core/src/test/java/org/apache/shardingsphere/traffic/algorithm/engine/TrafficEngineTest.java
+++ 
b/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-core/src/test/java/org/apache/shardingsphere/traffic/algorithm/engine/TrafficEngineTest.java
@@ -18,9 +18,8 @@
 package org.apache.shardingsphere.traffic.algorithm.engine;
 
 import org.apache.shardingsphere.infra.binder.LogicSQL;
-import org.apache.shardingsphere.infra.instance.ComputeNodeInstance;
 import org.apache.shardingsphere.infra.instance.InstanceContext;
-import org.apache.shardingsphere.infra.instance.definition.InstanceDefinition;
+import org.apache.shardingsphere.infra.instance.definition.InstanceId;
 import org.apache.shardingsphere.infra.instance.definition.InstanceType;
 import org.apache.shardingsphere.traffic.context.TrafficContext;
 import org.apache.shardingsphere.traffic.engine.TrafficEngine;
@@ -32,10 +31,10 @@ import org.junit.runner.RunWith;
 import org.mockito.Mock;
 import org.mockito.junit.MockitoJUnitRunner;
 
+import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Collection;
 import java.util.Collections;
-import java.util.LinkedList;
+import java.util.List;
 import java.util.Optional;
 
 import static org.hamcrest.CoreMatchers.is;
@@ -92,22 +91,19 @@ public final class TrafficEngineTest {
         when(trafficRule.findMatchedStrategyRule(logicSQL, 
false)).thenReturn(Optional.of(strategyRule));
         when(strategyRule.getLabels()).thenReturn(Arrays.asList("OLTP", 
"OLAP"));
         TrafficLoadBalanceAlgorithm loadBalancer = 
mock(TrafficLoadBalanceAlgorithm.class);
-        when(loadBalancer.getInstanceId("traffic", 
Arrays.asList("127.0.0.1@3307", 
"127.0.0.1@3308"))).thenReturn("127.0.0.1@3307");
+        List<InstanceId> instanceIds = mockComputeNodeInstances();
+        when(loadBalancer.getInstanceId("traffic", 
instanceIds)).thenReturn(new InstanceId("127.0.0.1@3307"));
         when(strategyRule.getLoadBalancer()).thenReturn(loadBalancer);
         when(strategyRule.getName()).thenReturn("traffic");
-        when(instanceContext.getComputeNodeInstances(InstanceType.PROXY, 
Arrays.asList("OLTP", "OLAP"))).thenReturn(mockComputeNodeInstances());
+        when(instanceContext.getComputeNodeInstanceIds(InstanceType.PROXY, 
Arrays.asList("OLTP", "OLAP"))).thenReturn(instanceIds);
         TrafficContext actual = trafficEngine.dispatch(logicSQL, false);
         assertThat(actual.getInstanceId(), is("127.0.0.1@3307"));
     }
     
-    private Collection<ComputeNodeInstance> mockComputeNodeInstances() {
-        Collection<ComputeNodeInstance> result = new LinkedList<>();
-        ComputeNodeInstance instanceOLAP = new ComputeNodeInstance(new 
InstanceDefinition(InstanceType.PROXY, "127.0.0.1@3307"));
-        instanceOLAP.setLabels(Collections.singletonList("OLAP"));
-        result.add(instanceOLAP);
-        ComputeNodeInstance instanceOLTP = new ComputeNodeInstance(new 
InstanceDefinition(InstanceType.PROXY, "127.0.0.1@3308"));
-        instanceOLTP.setLabels(Collections.singletonList("OLTP"));
-        result.add(instanceOLTP);
+    private List<InstanceId> mockComputeNodeInstances() {
+        List<InstanceId> result = new ArrayList<>();
+        result.add(new InstanceId("127.0.0.1@3307"));
+        result.add(new InstanceId("127.0.0.1@3308"));
         return result;
     }
 }
diff --git 
a/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-core/src/test/java/org/apache/shardingsphere/traffic/algorithm/loadbalance/RandomTrafficLoadBalanceAlgorithmTest.java
 
b/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-core/src/test/java/org/apache/shardingsphere/traffic/algorithm/loadbalance/RandomTrafficLoadBalanceAlgorithmTest.java
index b554177..b6d12b1 100644
--- 
a/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-core/src/test/java/org/apache/shardingsphere/traffic/algorithm/loadbalance/RandomTrafficLoadBalanceAlgorithmTest.java
+++ 
b/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-core/src/test/java/org/apache/shardingsphere/traffic/algorithm/loadbalance/RandomTrafficLoadBalanceAlgorithmTest.java
@@ -17,6 +17,7 @@
 
 package org.apache.shardingsphere.traffic.algorithm.loadbalance;
 
+import org.apache.shardingsphere.infra.instance.definition.InstanceId;
 import org.junit.Test;
 
 import java.util.Arrays;
@@ -33,7 +34,7 @@ public final class RandomTrafficLoadBalanceAlgorithmTest {
     
     @Test
     public void assertGetInstanceId() {
-        List<String> instanceIds = Arrays.asList("127.0.0.1@3307", 
"127.0.0.1@3308");
+        List<InstanceId> instanceIds = Arrays.asList(new 
InstanceId("127.0.0.1@3307"), new InstanceId("127.0.0.1@3308"));
         
assertTrue(instanceIds.contains(randomAlgorithm.getInstanceId("simple_traffic", 
instanceIds)));
         
assertTrue(instanceIds.contains(randomAlgorithm.getInstanceId("simple_traffic", 
instanceIds)));
         
assertTrue(instanceIds.contains(randomAlgorithm.getInstanceId("simple_traffic", 
instanceIds)));
diff --git 
a/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-core/src/test/java/org/apache/shardingsphere/traffic/algorithm/loadbalance/RoundRobinTrafficLoadBalanceAlgorithmTest.java
 
b/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-core/src/test/java/org/apache/shardingsphere/traffic/algorithm/loadbalance/RoundRobinTrafficLoadBalanceAlgorithmTest.java
index eba8e99..8854a91 100644
--- 
a/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-core/src/test/java/org/apache/shardingsphere/traffic/algorithm/loadbalance/RoundRobinTrafficLoadBalanceAlgorithmTest.java
+++ 
b/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-core/src/test/java/org/apache/shardingsphere/traffic/algorithm/loadbalance/RoundRobinTrafficLoadBalanceAlgorithmTest.java
@@ -17,6 +17,7 @@
 
 package org.apache.shardingsphere.traffic.algorithm.loadbalance;
 
+import org.apache.shardingsphere.infra.instance.definition.InstanceId;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
@@ -44,9 +45,9 @@ public final class RoundRobinTrafficLoadBalanceAlgorithmTest {
     
     @Test
     public void assertGetInstanceId() {
-        String instanceId1 = "127.0.0.1@3307";
-        String instanceId2 = "127.0.0.1@3308";
-        List<String> instanceIds = Arrays.asList(instanceId1, instanceId2);
+        InstanceId instanceId1 = new InstanceId("127.0.0.1@3307");
+        InstanceId instanceId2 = new InstanceId("127.0.0.1@3308");
+        List<InstanceId> instanceIds = Arrays.asList(instanceId1, instanceId2);
         assertThat(roundRobinAlgorithm.getInstanceId("simple_traffic", 
instanceIds), is(instanceId1));
         assertThat(roundRobinAlgorithm.getInstanceId("simple_traffic", 
instanceIds), is(instanceId2));
         assertThat(roundRobinAlgorithm.getInstanceId("simple_traffic", 
instanceIds), is(instanceId1));
diff --git 
a/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-core/src/test/java/org/apache/shardingsphere/traffic/rule/TrafficRuleTest.java
 
b/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-core/src/test/java/org/apache/shardingsphere/traffic/rule/TrafficRuleTest.java
index 0e233c7..a3bb4e9 100644
--- 
a/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-core/src/test/java/org/apache/shardingsphere/traffic/rule/TrafficRuleTest.java
+++ 
b/shardingsphere-kernel/shardingsphere-traffic/shardingsphere-traffic-core/src/test/java/org/apache/shardingsphere/traffic/rule/TrafficRuleTest.java
@@ -63,7 +63,7 @@ public final class TrafficRuleTest {
         Optional<TrafficStrategyRule> actual = 
trafficRule.findMatchedStrategyRule(createLogicSQL(true), false);
         assertTrue(actual.isPresent());
         assertThat(actual.get().getName(), is("sql_hint_traffic"));
-        assertThat(actual.get().getLabels(), is(Arrays.asList("OLTP", 
"OLAP")));
+        assertThat(actual.get().getLabels(), is(Sets.newHashSet("OLTP", 
"OLAP")));
         assertThat(actual.get().getTrafficAlgorithm(), 
instanceOf(SQLHintTrafficAlgorithm.class));
         assertThat(actual.get().getLoadBalancer(), 
instanceOf(RandomTrafficLoadBalanceAlgorithm.class));
     }
@@ -81,7 +81,7 @@ public final class TrafficRuleTest {
         Optional<TrafficStrategyRule> actual = 
trafficRule.findMatchedStrategyRule(createLogicSQL(false), true);
         assertTrue(actual.isPresent());
         assertThat(actual.get().getName(), is("transaction_traffic"));
-        assertThat(actual.get().getLabels(), 
is(Collections.singletonList("OLAP")));
+        assertThat(actual.get().getLabels(), is(Sets.newHashSet("OLAP")));
         assertThat(actual.get().getTrafficAlgorithm(), 
instanceOf(ProxyTrafficAlgorithm.class));
         assertThat(actual.get().getLoadBalancer(), 
instanceOf(RandomTrafficLoadBalanceAlgorithm.class));
     }
diff --git 
a/shardingsphere-proxy/shardingsphere-proxy-backend/src/test/java/org/apache/shardingsphere/proxy/backend/text/distsql/fixture/TestTrafficLoadBalanceAlgorithm.java
 
b/shardingsphere-proxy/shardingsphere-proxy-backend/src/test/java/org/apache/shardingsphere/proxy/backend/text/distsql/fixture/TestTrafficLoadBalanceAlgorithm.java
index 21a29a8..8be1730 100644
--- 
a/shardingsphere-proxy/shardingsphere-proxy-backend/src/test/java/org/apache/shardingsphere/proxy/backend/text/distsql/fixture/TestTrafficLoadBalanceAlgorithm.java
+++ 
b/shardingsphere-proxy/shardingsphere-proxy-backend/src/test/java/org/apache/shardingsphere/proxy/backend/text/distsql/fixture/TestTrafficLoadBalanceAlgorithm.java
@@ -19,6 +19,7 @@ package 
org.apache.shardingsphere.proxy.backend.text.distsql.fixture;
 
 import lombok.Getter;
 import lombok.Setter;
+import org.apache.shardingsphere.infra.instance.definition.InstanceId;
 import org.apache.shardingsphere.traffic.spi.TrafficLoadBalanceAlgorithm;
 
 import java.util.List;
@@ -36,7 +37,7 @@ public final class TestTrafficLoadBalanceAlgorithm implements 
TrafficLoadBalance
     }
     
     @Override
-    public String getInstanceId(final String name, final List<String> 
instanceIds) {
+    public InstanceId getInstanceId(final String name, final List<InstanceId> 
instanceIds) {
         return null;
     }
 }

Reply via email to