This is an automated email from the ASF dual-hosted git repository.
roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 2cb22ff5 [#719] feat(netty): Optimize allocation strategy (#739)
2cb22ff5 is described below
commit 2cb22ff5869fadf59f6013357de0827dc12f215d
Author: jokercurry <[email protected]>
AuthorDate: Mon Mar 20 22:11:31 2023 +0800
[#719] feat(netty): Optimize allocation strategy (#739)
### What changes were proposed in this pull request?
Users can choose between Netty and gRPC as the shuffle data transport
through client configuration.
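For example, a Spark job can opt into the Netty transport as follows (an
illustrative sketch only; the key comes from this description, `SparkConf` is
the standard Spark API, and leaving the key unset keeps the existing gRPC
path):

    import org.apache.spark.SparkConf;

    // Sketch: select the Netty transport for the Uniffle client.
    SparkConf sparkConf = new SparkConf();
    sparkConf.set("spark.rss.client.type", "GRPC_NETTY");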
### Why are the changes needed?
Fix: #719
### Does this PR introduce _any_ user-facing change?
No. However, if users want to use Netty as the data transfer method, they
need to set `spark.rss.client.type=GRPC_NETTY` or
`mapreduce.rss.client.type=GRPC_NETTY`.
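The MapReduce side is analogous (again a sketch; `JobConf` is the standard
Hadoop API, and the value is validated against the existing `ClientType` enum
via `ClientUtils.validateClientType` below):

    import org.apache.hadoop.mapred.JobConf;

    // Sketch: select the Netty transport for a MapReduce job.
    JobConf jobConf = new JobConf();
    jobConf.set("mapreduce.rss.client.type", "GRPC_NETTY");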
### How was this patch tested?
New unit tests.
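For context on how the choice takes effect: a shuffle server tags itself with
`GRPC_NETTY` when its Netty port is configured and with `GRPC` otherwise (see
`ShuffleServer#tagServer` below), and the coordinator only assigns servers
whose tags contain the client's validated type. A minimal sketch of the
matching, with the shapes borrowed from the tests in this diff:

    // Sketch: a client advertising GRPC_NETTY is only matched with
    // servers that registered the same transport tag.
    Set<String> tags = Sets.newHashSet(
        Constants.SHUFFLE_SERVER_VERSION, ClientType.GRPC_NETTY.name());
    List<ServerNode> nettyServers = clusterManager.getServerList(tags);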
---
.../hadoop/mapreduce/v2/app/RssMRAppMaster.java | 3 +
.../apache/spark/shuffle/RssShuffleManager.java | 3 +
.../apache/spark/shuffle/RssShuffleManager.java | 2 +
.../apache/uniffle/client/util/ClientUtils.java | 11 +++
.../org/apache/uniffle/client/ClientUtilsTest.java | 13 ++++
.../coordinator/SimpleClusterManagerTest.java | 78 ++++++++++++++++++++--
.../apache/uniffle/test/CoordinatorGrpcTest.java | 27 +++++++-
.../apache/uniffle/test/MRIntegrationTestBase.java | 3 +
.../org/apache/uniffle/server/ShuffleServer.java | 10 +++
9 files changed, 141 insertions(+), 9 deletions(-)
diff --git a/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java b/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
index 15981aac..044eae0b 100644
--- a/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
+++ b/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
@@ -130,6 +130,9 @@ public class RssMRAppMaster extends MRAppMaster {
assignmentTags.addAll(Arrays.asList(rawTags.split(",")));
}
assignmentTags.add(Constants.SHUFFLE_SERVER_VERSION);
+ String clientType = conf.get(RssMRConfig.RSS_CLIENT_TYPE);
+ ClientUtils.validateClientType(clientType);
+ assignmentTags.add(clientType);
final ScheduledExecutorService scheduledExecutorService =
Executors.newSingleThreadScheduledExecutor(
new ThreadFactory() {
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 2ecb6f8f..61e201cb 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -263,12 +263,15 @@ public class RssShuffleManager implements ShuffleManager {
// get all register info according to coordinator's response
Set<String> assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf);
+ ClientUtils.validateClientType(clientType);
+ assignmentTags.add(clientType);
int requiredShuffleServerNumber = RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf);
// retryInterval must bigger than `rss.server.heartbeat.timeout`, or maybe it will return the same result
long retryInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
+
Map<Integer, List<ShuffleServerInfo>> partitionToServers;
try {
partitionToServers = RetryUtils.retry(() -> {
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index e70026ac..4a574bc5 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -339,6 +339,8 @@ public class RssShuffleManager implements ShuffleManager {
id.get(), defaultRemoteStorage, dynamicConfEnabled, storageType, shuffleWriteClient);
Set<String> assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf);
+ ClientUtils.validateClientType(clientType);
+ assignmentTags.add(clientType);
int requiredShuffleServerNumber = RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf);
diff --git a/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java b/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java
index 0bdf7cf0..d9b51883 100644
--- a/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java
+++ b/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java
@@ -18,12 +18,16 @@
package org.apache.uniffle.client.util;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.List;
+import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.util.Constants;
import org.apache.uniffle.storage.util.StorageType;
@@ -122,4 +126,11 @@ public class ClientUtils {
+ "because of the poor performance of these two types.");
}
}
+
+ public static void validateClientType(String clientType) {
+ Set<String> types = Arrays.stream(ClientType.values()).map(Enum::name).collect(Collectors.toSet());
+ if (!types.contains(clientType)) {
+ throw new IllegalArgumentException(String.format("The value of %s should be one of %s", clientType, types));
+ }
+ }
}
diff --git a/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java b/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java
index 77f9cba5..577162a4 100644
--- a/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java
+++ b/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java
@@ -134,4 +134,17 @@ public class ClientUtilsTest {
List<CompletableFuture<Boolean>> futures3 = getFutures(false);
Awaitility.await().timeout(4, TimeUnit.SECONDS).until(() -> waitUntilDoneOrFail(futures3, true));
}
+
+ @Test
+ public void testValidateClientType() {
+ String clientType = "GRPC_NETTY";
+ ClientUtils.validateClientType(clientType);
+ clientType = "test";
+ try {
+ ClientUtils.validateClientType(clientType);
+ fail();
+ } catch (Exception e) {
+ // Ignore
+ }
+ }
}
diff --git a/coordinator/src/test/java/org/apache/uniffle/coordinator/SimpleClusterManagerTest.java b/coordinator/src/test/java/org/apache/uniffle/coordinator/SimpleClusterManagerTest.java
index c8dccf08..123dca4d 100644
--- a/coordinator/src/test/java/org/apache/uniffle/coordinator/SimpleClusterManagerTest.java
+++ b/coordinator/src/test/java/org/apache/uniffle/coordinator/SimpleClusterManagerTest.java
@@ -27,6 +27,7 @@ import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
+import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import org.apache.hadoop.conf.Configuration;
import org.junit.jupiter.api.AfterEach;
@@ -36,6 +37,8 @@ import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.apache.uniffle.common.ClientType;
+import org.apache.uniffle.common.ServerStatus;
import org.apache.uniffle.coordinator.metric.CoordinatorMetrics;
import static org.awaitility.Awaitility.await;
@@ -48,6 +51,8 @@ public class SimpleClusterManagerTest {
private static final Logger LOG = LoggerFactory.getLogger(SimpleClusterManagerTest.class);
private final Set<String> testTags = Sets.newHashSet("test");
+ private final Set<String> nettyTags = Sets.newHashSet("test", ClientType.GRPC_NETTY.name());
+ private final Set<String> grpcTags = Sets.newHashSet("test", ClientType.GRPC.name());
@BeforeEach
public void setUp() {
@@ -79,15 +84,15 @@ public class SimpleClusterManagerTest {
try (SimpleClusterManager clusterManager = new SimpleClusterManager(ssc, new Configuration())) {
ServerNode sn1 = new ServerNode("sn1", "ip", 0, 100L, 50L, 20,
- 10, testTags, true);
+ 10, grpcTags, true);
ServerNode sn2 = new ServerNode("sn2", "ip", 0, 100L, 50L, 21,
- 10, testTags, true);
+ 10, grpcTags, true);
ServerNode sn3 = new ServerNode("sn3", "ip", 0, 100L, 50L, 20,
- 11, testTags, true);
+ 11, grpcTags, true);
clusterManager.add(sn1);
clusterManager.add(sn2);
clusterManager.add(sn3);
- List<ServerNode> serverNodes = clusterManager.getServerList(testTags);
+ List<ServerNode> serverNodes = clusterManager.getServerList(grpcTags);
assertEquals(3, serverNodes.size());
Set<String> expectedIds = Sets.newHashSet("sn1", "sn2", "sn3");
assertEquals(expectedIds, serverNodes.stream().map(ServerNode::getId).collect(Collectors.toSet()));
@@ -98,7 +103,7 @@ public class SimpleClusterManagerTest {
sn2 = new ServerNode("sn2", "ip", 0, 100L, 50L, 21,
10, Sets.newHashSet("test", "new_tag"), true);
ServerNode sn4 = new ServerNode("sn4", "ip", 0, 100L, 51L, 20,
- 10, testTags, true);
+ 10, grpcTags, true);
clusterManager.add(sn1);
clusterManager.add(sn2);
clusterManager.add(sn4);
@@ -109,7 +114,7 @@ public class SimpleClusterManagerTest {
assertTrue(serverNodes.contains(sn4));
Map<String, Set<ServerNode>> tagToNodes = clusterManager.getTagToNodes();
- assertEquals(2, tagToNodes.size());
+ assertEquals(3, tagToNodes.size());
Set<ServerNode> newTagNodes = tagToNodes.get("new_tag");
assertEquals(2, newTagNodes.size());
@@ -124,6 +129,67 @@ public class SimpleClusterManagerTest {
}
}
+ @Test
+ public void getServerListForNettyTest() throws Exception {
+ CoordinatorConf ssc = new CoordinatorConf();
+ ssc.setLong(CoordinatorConf.COORDINATOR_HEARTBEAT_TIMEOUT, 30 * 1000L);
+ try (SimpleClusterManager clusterManager = new SimpleClusterManager(ssc, new Configuration())) {
+
+ ServerNode sn1 = new ServerNode("sn1", "ip", 0, 100L, 50L, 20,
+ 10, nettyTags, true, ServerStatus.ACTIVE, Maps.newConcurrentMap(), 1);
+ ServerNode sn2 = new ServerNode("sn2", "ip", 0, 100L, 50L, 21,
+ 10, nettyTags, true, ServerStatus.ACTIVE, Maps.newConcurrentMap(), 1);
+ ServerNode sn3 = new ServerNode("sn3", "ip", 0, 100L, 50L, 20,
+ 11, nettyTags, true, ServerStatus.ACTIVE, Maps.newConcurrentMap(), 1);
+ ServerNode sn4 = new ServerNode("sn4", "ip", 0, 100L, 50L, 20,
+ 11, grpcTags, true);
+ clusterManager.add(sn1);
+ clusterManager.add(sn2);
+ clusterManager.add(sn3);
+ clusterManager.add(sn4);
+
+ List<ServerNode> serverNodes2 = clusterManager.getServerList(nettyTags);
+ assertEquals(3, serverNodes2.size());
+
+ List<ServerNode> serverNodes3 = clusterManager.getServerList(grpcTags);
+ assertEquals(1, serverNodes3.size());
+
+ List<ServerNode> serverNodes4 = clusterManager.getServerList(testTags);
+ assertEquals(4, serverNodes4.size());
+
+ Map<String, Set<ServerNode>> tagToNodes = clusterManager.getTagToNodes();
+ assertEquals(3, tagToNodes.size());
+
+ // tag changes
+ sn1 = new ServerNode("sn1", "ip", 0, 100L, 50L, 20,
+ 10, Sets.newHashSet("new_tag"), true, ServerStatus.ACTIVE,
Maps.newConcurrentMap(), 1);
+ sn2 = new ServerNode("sn2", "ip", 0, 100L, 50L, 21,
+ 10, Sets.newHashSet("test", "new_tag"),
+ true, ServerStatus.ACTIVE, Maps.newConcurrentMap(), 1);
+ sn4 = new ServerNode("sn4", "ip", 0, 100L, 51L, 20,
+ 10, grpcTags, true);
+ clusterManager.add(sn1);
+ clusterManager.add(sn2);
+ clusterManager.add(sn4);
+ Set<ServerNode> testTagNodesForNetty = tagToNodes.get(ClientType.GRPC_NETTY.name());
+ assertEquals(1, testTagNodesForNetty.size());
+
+ List<ServerNode> serverNodes = clusterManager.getServerList(grpcTags);
+ assertEquals(1, serverNodes.size());
+ assertTrue(serverNodes.contains(sn4));
+
+ Set<ServerNode> newTagNodes = tagToNodes.get("new_tag");
+ assertEquals(2, newTagNodes.size());
+ assertTrue(newTagNodes.contains(sn1));
+ assertTrue(newTagNodes.contains(sn2));
+ Set<ServerNode> testTagNodes = tagToNodes.get("test");
+ assertEquals(3, testTagNodes.size());
+ assertTrue(testTagNodes.contains(sn2));
+ assertTrue(testTagNodes.contains(sn3));
+ assertTrue(testTagNodes.contains(sn4));
+ }
+ }
+
@Test
public void testGetCorrectServerNodesWhenOneNodeRemovedAndUnhealthyNodeFound() throws Exception {
CoordinatorConf ssc = new CoordinatorConf();
diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/CoordinatorGrpcTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/CoordinatorGrpcTest.java
index 9b63df7c..800184b2 100644
--- a/integration-test/common/src/test/java/org/apache/uniffle/test/CoordinatorGrpcTest.java
+++ b/integration-test/common/src/test/java/org/apache/uniffle/test/CoordinatorGrpcTest.java
@@ -17,10 +17,12 @@
package org.apache.uniffle.test;
+import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
@@ -31,10 +33,12 @@ import org.apache.uniffle.client.request.RssApplicationInfoRequest;
import org.apache.uniffle.client.request.RssGetShuffleAssignmentsRequest;
import org.apache.uniffle.client.response.RssApplicationInfoResponse;
import org.apache.uniffle.client.response.RssGetShuffleAssignmentsResponse;
+import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.PartitionRange;
import org.apache.uniffle.common.ShuffleRegisterInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssBaseConf;
+import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.storage.StorageInfo;
import org.apache.uniffle.common.storage.StorageMedia;
@@ -121,11 +125,28 @@ public class CoordinatorGrpcTest extends CoordinatorTestBase {
@Test
public void getShuffleAssignmentsTest() throws Exception {
- String appId = "getShuffleAssignmentsTest";
+ final String appId = "getShuffleAssignmentsTest";
CoordinatorTestUtils.waitForRegister(coordinatorClient,2);
+ // When the shuffleServerHeartbeat Test is completed before the current test,
+ // the server's tags will be [ss_v4, GRPC_NETTY] and [ss_v4, GRPC], respectively.
+ // We need to remove the first machine's tag from GRPC_NETTY to GRPC
+ shuffleServers.get(0).stopServer();
+ RssConf shuffleServerConf = shuffleServers.get(0).getShuffleServerConf();
+ Class<RssConf> clazz = RssConf.class;
+ Field field = clazz.getDeclaredField("settings");
+ field.setAccessible(true);
+ ((ConcurrentHashMap) field.get(shuffleServerConf)).remove(ShuffleServerConf.NETTY_SERVER_PORT.key());
+ String storageTypeJsonSource = String.format("{\"%s\": \"ssd\"}", baseDir);
+ withEnvironmentVariables("RSS_ENV_KEY", storageTypeJsonSource).execute(() -> {
+ ShuffleServer ss = new ShuffleServer((ShuffleServerConf) shuffleServerConf);
+ ss.start();
+ shuffleServers.set(0, ss);
+ });
+ Thread.sleep(5000);
+ // add tag when ClientType is `GRPC`
RssGetShuffleAssignmentsRequest request = new RssGetShuffleAssignmentsRequest(
appId, 1, 10, 4, 1,
- Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION));
+ Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION, ClientType.GRPC.name()));
RssGetShuffleAssignmentsResponse response = coordinatorClient.getShuffleAssignments(request);
Set<Integer> expectedStart = Sets.newHashSet(0, 4, 8);
@@ -157,7 +178,7 @@ public class CoordinatorGrpcTest extends CoordinatorTestBase {
request = new RssGetShuffleAssignmentsRequest(
appId, 1, 10, 4, 2,
- Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION));
+ Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION, ClientType.GRPC.name()));
response = coordinatorClient.getShuffleAssignments(request);
serverToPartitionRanges = response.getServerToPartitionRanges();
assertEquals(2, serverToPartitionRanges.size());
diff --git a/integration-test/mr/src/test/java/org/apache/uniffle/test/MRIntegrationTestBase.java b/integration-test/mr/src/test/java/org/apache/uniffle/test/MRIntegrationTestBase.java
index 1ceab307..ed22f953 100644
--- a/integration-test/mr/src/test/java/org/apache/uniffle/test/MRIntegrationTestBase.java
+++ b/integration-test/mr/src/test/java/org/apache/uniffle/test/MRIntegrationTestBase.java
@@ -44,6 +44,8 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
+import org.apache.uniffle.common.ClientType;
+
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
@@ -165,6 +167,7 @@ public class MRIntegrationTestBase extends IntegrationTestBase {
jobConf.set(MRJobConfig.MAPREDUCE_APPLICATION_CLASSPATH,
"$PWD/rss.jar/" + localFile.getName() + "," +
MRJobConfig.DEFAULT_MAPREDUCE_APPLICATION_CLASSPATH);
jobConf.set(RssMRConfig.RSS_COORDINATOR_QUORUM, COORDINATOR_QUORUM);
+ jobConf.set(RssMRConfig.RSS_CLIENT_TYPE, ClientType.GRPC.name());
updateRssConfiguration(jobConf);
runMRApp(jobConf, getTestTool(), getTestArgs());
diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServer.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServer.java
index 3a47a81d..263884c1 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleServer.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServer.java
@@ -35,6 +35,7 @@ import org.slf4j.LoggerFactory;
import picocli.CommandLine;
import org.apache.uniffle.common.Arguments;
+import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.ServerStatus;
import org.apache.uniffle.common.exception.InvalidRequestException;
import org.apache.uniffle.common.metrics.GRPCMetrics;
@@ -248,9 +249,18 @@ public class ShuffleServer {
if (CollectionUtils.isNotEmpty(configuredTags)) {
tags.addAll(configuredTags);
}
+ tagServer();
LOG.info("Server tags: {}", tags);
}
+ private void tagServer() {
+ if (nettyServerEnabled) {
+ tags.add(ClientType.GRPC_NETTY.name());
+ } else {
+ tags.add(ClientType.GRPC.name());
+ }
+ }
+
private void registerMetrics() throws Exception {
LOG.info("Register metrics");
CollectorRegistry shuffleServerCollectorRegistry = new CollectorRegistry(true);