This is an automated email from the ASF dual-hosted git repository.
roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 9a227da [Improvement] Introduce config to customize assignment server
numbers in client side (#100)
9a227da is described below
commit 9a227da1d4f6de2f2267225ae1d547414be5e234
Author: Junfan Zhang <[email protected]>
AuthorDate: Mon Aug 1 20:12:17 2022 +0800
[Improvement] Introduce config to customize assignment server numbers in
client side (#100)
### What changes were proposed in this pull request?
[Improvement] Introduce a config to customize the number of assigned shuffle
servers on the client side.
**Changelog**
1. Introduce the config
`<client_type>.rss.client.assignment.shuffle.nodes.max`.
### Why are the changes needed?
Currently, the number of assigned shuffle servers is determined solely by the
coordinator's config `rss.coordinator.shuffle.nodes.max`, which is not suitable
for all Spark jobs.
We should introduce a new config that lets the client specify the number of
assigned shuffle servers. `rss.coordinator.shuffle.nodes.max` should then act as
an upper limit on the number requested by clients.
### Does this PR introduce _any_ user-facing change?
YES.
### How was this patch tested?
Unit tests, plus a new integration test (`AssignmentServerNodesNumberTest`).
---
.../org/apache/hadoop/mapreduce/RssMRConfig.java | 7 +-
.../hadoop/mapreduce/v2/app/RssMRAppMaster.java | 17 ++-
.../hadoop/mapred/SortWriteBufferManagerTest.java | 2 +-
.../hadoop/mapreduce/task/reduce/FetcherTest.java | 2 +-
.../org/apache/spark/shuffle/RssSparkConfig.java | 4 +
.../apache/spark/shuffle/RssShuffleManager.java | 4 +-
.../apache/spark/shuffle/RssShuffleManager.java | 5 +-
.../uniffle/client/api/ShuffleWriteClient.java | 2 +-
.../client/impl/ShuffleWriteClientImpl.java | 4 +-
.../uniffle/client/util/RssClientConfig.java | 4 +
.../uniffle/coordinator/AssignmentStrategy.java | 2 +-
.../coordinator/BasicAssignmentStrategy.java | 8 +-
.../uniffle/coordinator/CoordinatorConf.java | 2 +-
.../coordinator/CoordinatorGrpcService.java | 10 +-
.../PartitionBalanceAssignmentStrategy.java | 12 +-
.../coordinator/BasicAssignmentStrategyTest.java | 104 ++++++++++++++++-
.../PartitionBalanceAssignmentStrategyTest.java | 124 ++++++++++++++++++---
docs/client_guide.md | 1 +
.../test/AssignmentServerNodesNumberTest.java | 106 ++++++++++++++++++
.../uniffle/test/AssignmentWithTagsTest.java | 10 +-
.../client/impl/grpc/CoordinatorGrpcClient.java | 12 +-
.../request/RssGetShuffleAssignmentsRequest.java | 14 +++
proto/src/main/proto/Rss.proto | 1 +
23 files changed, 409 insertions(+), 48 deletions(-)
diff --git
a/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java
b/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java
index ed1f90f..ef47e21 100644
--- a/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java
+++ b/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java
@@ -144,7 +144,12 @@ public class RssMRConfig {
public static final int RSS_ACCESS_TIMEOUT_MS_DEFAULT_VALUE =
RssClientConfig.RSS_ACCESS_TIMEOUT_MS_DEFAULT_VALUE;
public static final String RSS_CLIENT_ASSIGNMENT_TAGS =
- MR_RSS_CONFIG_PREFIX + RssClientConfig.RSS_CLIENT_ASSIGNMENT_TAGS;
+ MR_RSS_CONFIG_PREFIX + RssClientConfig.RSS_CLIENT_ASSIGNMENT_TAGS;
+
+ public static final String RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER =
+ RssClientConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER;
+ public static final int
RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER_DEFAULT_VALUE =
+
RssClientConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER_DEFAULT_VALUE;
public static final String RSS_CONF_FILE = "rss_conf.xml";
diff --git
a/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
b/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
index e163eec..c65f2a2 100644
---
a/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
+++
b/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
@@ -128,10 +128,23 @@ public class RssMRAppMaster extends MRAppMaster {
}
assignmentTags.add(Constants.SHUFFLE_SERVER_VERSION);
+ int requiredAssignmentShuffleServersNum = conf.getInt(
+ RssMRConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER,
+ RssMRConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER_DEFAULT_VALUE
+ );
+
ApplicationAttemptId applicationAttemptId =
RssMRUtils.getApplicationAttemptId();
String appId = applicationAttemptId.toString();
- ShuffleAssignmentsInfo response = client.getShuffleAssignments(
- appId, 0, numReduceTasks, 1, Sets.newHashSet(assignmentTags));
+
+ ShuffleAssignmentsInfo response =
+ client.getShuffleAssignments(
+ appId,
+ 0,
+ numReduceTasks,
+ 1,
+ Sets.newHashSet(assignmentTags),
+ requiredAssignmentShuffleServersNum
+ );
Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges =
response.getServerToPartitionRanges();
final ScheduledExecutorService scheduledExecutorService =
Executors.newSingleThreadScheduledExecutor(
diff --git
a/client-mr/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
b/client-mr/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
index 398b52a..7613883 100644
---
a/client-mr/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
+++
b/client-mr/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
@@ -314,7 +314,7 @@ public class SortWriteBufferManagerTest {
}
@Override
- public ShuffleAssignmentsInfo getShuffleAssignments(String appId, int
shuffleId, int partitionNum, int partitionNumPerRange, Set<String>
requiredTags) {
+ public ShuffleAssignmentsInfo getShuffleAssignments(String appId, int
shuffleId, int partitionNum, int partitionNumPerRange, Set<String>
requiredTags, int assignmentShuffleServerNumber) {
return null;
}
diff --git
a/client-mr/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
b/client-mr/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
index ee2539b..cc622f0 100644
---
a/client-mr/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
+++
b/client-mr/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
@@ -401,7 +401,7 @@ public class FetcherTest {
}
@Override
- public ShuffleAssignmentsInfo getShuffleAssignments(String appId, int
shuffleId, int partitionNum, int partitionNumPerRange, Set<String>
requiredTags) {
+ public ShuffleAssignmentsInfo getShuffleAssignments(String appId, int
shuffleId, int partitionNum, int partitionNumPerRange, Set<String>
requiredTags, int assignmentShuffleServerNumber) {
return null;
}
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index 875c9a5..6b549b1 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -203,6 +203,10 @@ public class RssSparkConfig {
+ "whether this conf is set or not"))
.createWithDefault("");
+ public static final ConfigEntry<Integer>
RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER = createIntegerBuilder(
+ new ConfigBuilder(SPARK_RSS_CONFIG_PREFIX +
RssClientConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER))
+
.createWithDefault(RssClientConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER_DEFAULT_VALUE);
+
public static final ConfigEntry<String> RSS_COORDINATOR_QUORUM =
createStringBuilder(
new ConfigBuilder(SPARK_RSS_CONFIG_PREFIX +
RssClientConfig.RSS_COORDINATOR_QUORUM)
.doc("Coordinator quorum"))
diff --git
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index c313747..ec84308 100644
---
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -218,9 +218,11 @@ public class RssShuffleManager implements ShuffleManager {
// get all register info according to coordinator's response
Set<String> assignmentTags =
RssSparkShuffleUtils.getAssignmentTags(sparkConf);
+ int requiredShuffleServerNumber =
sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER);
+
ShuffleAssignmentsInfo response = shuffleWriteClient.getShuffleAssignments(
appId, shuffleId, dependency.partitioner().numPartitions(),
- partitionNumPerRange, assignmentTags);
+ partitionNumPerRange, assignmentTags, requiredShuffleServerNumber);
Map<Integer, List<ShuffleServerInfo>> partitionToServers =
response.getPartitionToServers();
startHeartbeat();
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 32239b3..80fac99 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -256,12 +256,15 @@ public class RssShuffleManager implements ShuffleManager {
Set<String> assignmentTags =
RssSparkShuffleUtils.getAssignmentTags(sparkConf);
+ int requiredShuffleServerNumber =
sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER);
+
ShuffleAssignmentsInfo response = shuffleWriteClient.getShuffleAssignments(
id.get(),
shuffleId,
dependency.partitioner().numPartitions(),
1,
- assignmentTags);
+ assignmentTags,
+ requiredShuffleServerNumber);
Map<Integer, List<ShuffleServerInfo>> partitionToServers =
response.getPartitionToServers();
startHeartbeat();
diff --git
a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
index 2cf3685..d5981c4 100644
--- a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
+++ b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
@@ -60,7 +60,7 @@ public interface ShuffleWriteClient {
int bitmapNum);
ShuffleAssignmentsInfo getShuffleAssignments(String appId, int shuffleId,
int partitionNum,
- int partitionNumPerRange, Set<String> requiredTags);
+ int partitionNumPerRange, Set<String> requiredTags, int
assignmentShuffleServerNumber);
Roaring64NavigableMap getShuffleResult(String clientType,
Set<ShuffleServerInfo> shuffleServerInfoSet,
String appId, int shuffleId, int partitionId);
diff --git
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
index 43d8b3b..c6c13e0 100644
---
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
+++
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
@@ -375,9 +375,9 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
@Override
public ShuffleAssignmentsInfo getShuffleAssignments(String appId, int
shuffleId, int partitionNum,
- int partitionNumPerRange, Set<String> requiredTags) {
+ int partitionNumPerRange, Set<String> requiredTags, int
assignmentShuffleServerNumber) {
RssGetShuffleAssignmentsRequest request = new
RssGetShuffleAssignmentsRequest(
- appId, shuffleId, partitionNum, partitionNumPerRange, replica,
requiredTags);
+ appId, shuffleId, partitionNum, partitionNumPerRange, replica,
requiredTags, assignmentShuffleServerNumber);
RssGetShuffleAssignmentsResponse response = new
RssGetShuffleAssignmentsResponse(ResponseStatusCode.INTERNAL_ERROR);
for (CoordinatorClient coordinatorClient : coordinatorClients) {
diff --git
a/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java
b/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java
index 0b42d49..eb6006a 100644
--- a/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java
+++ b/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java
@@ -65,4 +65,8 @@ public class RssClientConfig {
public static final int RSS_ACCESS_TIMEOUT_MS_DEFAULT_VALUE = 10000;
public static final String RSS_DYNAMIC_CLIENT_CONF_ENABLED =
"rss.dynamicClientConf.enabled";
public static final boolean RSS_DYNAMIC_CLIENT_CONF_ENABLED_DEFAULT_VALUE =
true;
+
+ public static final String RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER =
+ "rss.client.assignment.shuffle.nodes.max";
+ public static final int
RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER_DEFAULT_VALUE = -1;
}
diff --git
a/coordinator/src/main/java/org/apache/uniffle/coordinator/AssignmentStrategy.java
b/coordinator/src/main/java/org/apache/uniffle/coordinator/AssignmentStrategy.java
index 86ddb18..36d1908 100644
---
a/coordinator/src/main/java/org/apache/uniffle/coordinator/AssignmentStrategy.java
+++
b/coordinator/src/main/java/org/apache/uniffle/coordinator/AssignmentStrategy.java
@@ -22,6 +22,6 @@ import java.util.Set;
public interface AssignmentStrategy {
PartitionRangeAssignment assign(int totalPartitionNum, int
partitionNumPerRange,
- int replica, Set<String> requiredTags);
+ int replica, Set<String> requiredTags, int requiredShuffleServerNumber);
}
diff --git
a/coordinator/src/main/java/org/apache/uniffle/coordinator/BasicAssignmentStrategy.java
b/coordinator/src/main/java/org/apache/uniffle/coordinator/BasicAssignmentStrategy.java
index 8d4eb54..54ca2c2 100644
---
a/coordinator/src/main/java/org/apache/uniffle/coordinator/BasicAssignmentStrategy.java
+++
b/coordinator/src/main/java/org/apache/uniffle/coordinator/BasicAssignmentStrategy.java
@@ -41,10 +41,14 @@ public class BasicAssignmentStrategy implements
AssignmentStrategy {
@Override
public PartitionRangeAssignment assign(int totalPartitionNum, int
partitionNumPerRange,
- int replica, Set<String> requiredTags) {
+ int replica, Set<String> requiredTags, int requiredShuffleServerNumber) {
List<PartitionRange> ranges =
CoordinatorUtils.generateRanges(totalPartitionNum, partitionNumPerRange);
int shuffleNodesMax = clusterManager.getShuffleNodesMax();
- List<ServerNode> servers = getRequiredServers(requiredTags,
shuffleNodesMax);
+ int expectedShuffleNodesNum = shuffleNodesMax;
+ if (requiredShuffleServerNumber < shuffleNodesMax &&
requiredShuffleServerNumber > 0) {
+ expectedShuffleNodesNum = requiredShuffleServerNumber;
+ }
+ List<ServerNode> servers = getRequiredServers(requiredTags,
expectedShuffleNodesNum);
if (servers.isEmpty() || servers.size() < replica) {
return new PartitionRangeAssignment(null);
}
diff --git
a/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorConf.java
b/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorConf.java
index bdad2d6..ebb50ce 100644
---
a/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorConf.java
+++
b/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorConf.java
@@ -61,7 +61,7 @@ public class CoordinatorConf extends RssBaseConf {
.key("rss.coordinator.shuffle.nodes.max")
.intType()
.defaultValue(9)
- .withDescription("The max number of shuffle server when do the
assignment");
+ .withDescription("The max limitation number of shuffle server when do
the assignment");
public static final ConfigOption<List<String>> COORDINATOR_ACCESS_CHECKERS =
ConfigOptions
.key("rss.coordinator.access.checkers")
.stringType()
diff --git
a/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorGrpcService.java
b/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorGrpcService.java
index ce14458..d2c3fc2 100644
---
a/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorGrpcService.java
+++
b/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorGrpcService.java
@@ -109,15 +109,21 @@ public class CoordinatorGrpcService extends
CoordinatorServerGrpc.CoordinatorSer
final int partitionNumPerRange = request.getPartitionNumPerRange();
final int replica = request.getDataReplica();
final Set<String> requiredTags =
Sets.newHashSet(request.getRequireTagsList());
+ final int requiredShuffleServerNumber =
request.getAssignmentShuffleServerNumber();
LOG.info("Request of getShuffleAssignments for appId[" + appId
+ "], shuffleId[" + shuffleId + "], partitionNum[" + partitionNum
- + "], partitionNumPerRange[" + partitionNumPerRange + "], replica[" +
replica + "]");
+ + "], partitionNumPerRange[" + partitionNumPerRange + "], replica[" +
replica
+ + "], requiredTags[" + requiredTags
+ + "], requiredShuffleServerNumber[" + requiredShuffleServerNumber + "]"
+ );
GetShuffleAssignmentsResponse response;
try {
final PartitionRangeAssignment pra =
- coordinatorServer.getAssignmentStrategy().assign(partitionNum,
partitionNumPerRange, replica, requiredTags);
+ coordinatorServer
+ .getAssignmentStrategy()
+ .assign(partitionNum, partitionNumPerRange, replica,
requiredTags, requiredShuffleServerNumber);
response =
CoordinatorUtils.toGetShuffleAssignmentsResponse(pra);
logAssignmentResult(appId, shuffleId, pra);
diff --git
a/coordinator/src/main/java/org/apache/uniffle/coordinator/PartitionBalanceAssignmentStrategy.java
b/coordinator/src/main/java/org/apache/uniffle/coordinator/PartitionBalanceAssignmentStrategy.java
index ba92477..d074b8c 100644
---
a/coordinator/src/main/java/org/apache/uniffle/coordinator/PartitionBalanceAssignmentStrategy.java
+++
b/coordinator/src/main/java/org/apache/uniffle/coordinator/PartitionBalanceAssignmentStrategy.java
@@ -66,7 +66,8 @@ public class PartitionBalanceAssignmentStrategy implements
AssignmentStrategy {
int totalPartitionNum,
int partitionNumPerRange,
int replica,
- Set<String> requiredTags) {
+ Set<String> requiredTags,
+ int requiredShuffleServerNumber) {
if (partitionNumPerRange != 1) {
throw new RuntimeException("PartitionNumPerRange must be one");
@@ -107,8 +108,13 @@ public class PartitionBalanceAssignmentStrategy implements
AssignmentStrategy {
throw new RuntimeException("There isn't enough shuffle servers");
}
- int expectNum = clusterManager.getShuffleNodesMax();
- if (nodes.size() < clusterManager.getShuffleNodesMax()) {
+ final int assignmentMaxNum = clusterManager.getShuffleNodesMax();
+ int expectNum = assignmentMaxNum;
+ if (requiredShuffleServerNumber < assignmentMaxNum &&
requiredShuffleServerNumber > 0) {
+ expectNum = requiredShuffleServerNumber;
+ }
+
+ if (nodes.size() < expectNum) {
LOG.warn("Can't get expected servers [" + expectNum + "] and found
only [" + nodes.size() + "]");
expectNum = nodes.size();
}
diff --git
a/coordinator/src/test/java/org/apache/uniffle/coordinator/BasicAssignmentStrategyTest.java
b/coordinator/src/test/java/org/apache/uniffle/coordinator/BasicAssignmentStrategyTest.java
index 6f80eb3..a1f79cf 100644
---
a/coordinator/src/test/java/org/apache/uniffle/coordinator/BasicAssignmentStrategyTest.java
+++
b/coordinator/src/test/java/org/apache/uniffle/coordinator/BasicAssignmentStrategyTest.java
@@ -22,6 +22,8 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import com.google.common.collect.Sets;
+import java.util.Collection;
+import java.util.stream.Collectors;
import org.apache.uniffle.common.PartitionRange;
import java.io.IOException;
@@ -64,7 +66,7 @@ public class BasicAssignmentStrategyTest {
20 - i, 0, tags, true));
}
- PartitionRangeAssignment pra = strategy.assign(100, 10, 2, tags);
+ PartitionRangeAssignment pra = strategy.assign(100, 10, 2, tags, -1);
SortedMap<PartitionRange, List<ServerNode>> assignments =
pra.getAssignments();
assertEquals(10, assignments.size());
@@ -90,14 +92,14 @@ public class BasicAssignmentStrategyTest {
clusterManager.add(new ServerNode(String.valueOf(i), "", 0, 0, 0,
0, 0, tags, true));
}
- PartitionRangeAssignment pra = strategy.assign(100, 10, 2, tags);
+ PartitionRangeAssignment pra = strategy.assign(100, 10, 2, tags, -1);
SortedMap<PartitionRange, List<ServerNode>> assignments =
pra.getAssignments();
Set<ServerNode> serverNodes1 = Sets.newHashSet();
for (Map.Entry<PartitionRange, List<ServerNode>> assignment :
assignments.entrySet()) {
serverNodes1.addAll(assignment.getValue());
}
- pra = strategy.assign(100, 10, 2, tags);
+ pra = strategy.assign(100, 10, 2, tags, -1);
assignments = pra.getAssignments();
Set<ServerNode> serverNodes2 = Sets.newHashSet();
for (Map.Entry<PartitionRange, List<ServerNode>> assignment :
assignments.entrySet()) {
@@ -118,13 +120,13 @@ public class BasicAssignmentStrategyTest {
0, 0, tags, true);
clusterManager.add(sn1);
- PartitionRangeAssignment pra = strategy.assign(100, 10, 2, tags);
+ PartitionRangeAssignment pra = strategy.assign(100, 10, 2, tags, -1);
// nodeNum < replica
assertNull(pra.getAssignments());
// nodeNum = replica
clusterManager.add(sn2);
- pra = strategy.assign(100, 10, 2, tags);
+ pra = strategy.assign(100, 10, 2, tags, -1);
SortedMap<PartitionRange, List<ServerNode>> assignments =
pra.getAssignments();
Set<ServerNode> serverNodes = Sets.newHashSet();
for (Map.Entry<PartitionRange, List<ServerNode>> assignment :
assignments.entrySet()) {
@@ -136,7 +138,7 @@ public class BasicAssignmentStrategyTest {
// nodeNum > replica & nodeNum < shuffleNodesMax
clusterManager.add(sn3);
- pra = strategy.assign(100, 10, 2, tags);
+ pra = strategy.assign(100, 10, 2, tags, -1);
assignments = pra.getAssignments();
serverNodes = Sets.newHashSet();
for (Map.Entry<PartitionRange, List<ServerNode>> assignment :
assignments.entrySet()) {
@@ -147,4 +149,94 @@ public class BasicAssignmentStrategyTest {
assertTrue(serverNodes.contains(sn2));
assertTrue(serverNodes.contains(sn3));
}
+
+ @Test
+ public void testAssignmentShuffleNodesNum() {
+ Set<String> serverTags = Sets.newHashSet("tag-1");
+
+ for (int i = 0; i < 20; ++i) {
+ clusterManager.add(new ServerNode("t1-" + i, "", 0, 0, 0,
+ 20 - i, 0, serverTags, true));
+ }
+
+ /**
+ * case1: user specify the illegal shuffle node num(<0)
+ * it will use the default shuffle nodes num when having enough servers.
+ */
+ PartitionRangeAssignment pra = strategy.assign(100, 10, 1, serverTags, -1);
+ assertEquals(
+ shuffleNodesMax,
+ pra.getAssignments()
+ .values()
+ .stream()
+ .flatMap(Collection::stream)
+ .collect(Collectors.toSet())
+ .size()
+ );
+
+ /**
+ * case2: user specify the illegal shuffle node num(==0)
+ * it will use the default shuffle nodes num when having enough servers.
+ */
+ pra = strategy.assign(100, 10, 1, serverTags, 0);
+ assertEquals(
+ shuffleNodesMax,
+ pra.getAssignments()
+ .values()
+ .stream()
+ .flatMap(Collection::stream)
+ .collect(Collectors.toSet())
+ .size()
+ );
+
+ /**
+ * case3: user specify the illegal shuffle node num(>default max
limitation)
+ * it will use the default shuffle nodes num when having enough servers
+ */
+ pra = strategy.assign(100, 10, 1, serverTags, shuffleNodesMax + 10);
+ assertEquals(
+ shuffleNodesMax,
+ pra.getAssignments()
+ .values()
+ .stream()
+ .flatMap(Collection::stream)
+ .collect(Collectors.toSet())
+ .size()
+ );
+
+ /**
+ * case4: user specify the legal shuffle node num,
+ * it will use the customized shuffle nodes num when having enough servers
+ */
+ pra = strategy.assign(100, 10, 1, serverTags, shuffleNodesMax - 1);
+ assertEquals(
+ shuffleNodesMax - 1,
+ pra.getAssignments()
+ .values()
+ .stream()
+ .flatMap(Collection::stream)
+ .collect(Collectors.toSet())
+ .size()
+ );
+
+ /**
+ * case5: user specify the legal shuffle node num, but cluster dont have
enough servers,
+ * it will return the remaining servers.
+ */
+ serverTags = Sets.newHashSet("tag-2");
+ for (int i = 0; i < shuffleNodesMax - 1; ++i) {
+ clusterManager.add(new ServerNode("t2-" + i, "", 0, 0, 0,
+ 20 - i, 0, serverTags, true));
+ }
+ pra = strategy.assign(100, 10, 1, serverTags, shuffleNodesMax);
+ assertEquals(
+ shuffleNodesMax - 1,
+ pra.getAssignments()
+ .values()
+ .stream()
+ .flatMap(Collection::stream)
+ .collect(Collectors.toSet())
+ .size()
+ );
+ }
}
diff --git
a/coordinator/src/test/java/org/apache/uniffle/coordinator/PartitionBalanceAssignmentStrategyTest.java
b/coordinator/src/test/java/org/apache/uniffle/coordinator/PartitionBalanceAssignmentStrategyTest.java
index ae3b6e3..47cc26f 100644
---
a/coordinator/src/test/java/org/apache/uniffle/coordinator/PartitionBalanceAssignmentStrategyTest.java
+++
b/coordinator/src/test/java/org/apache/uniffle/coordinator/PartitionBalanceAssignmentStrategyTest.java
@@ -18,6 +18,7 @@
package org.apache.uniffle.coordinator;
import java.io.IOException;
+import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Set;
@@ -27,6 +28,7 @@ import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.google.common.util.concurrent.Uninterruptibles;
+import java.util.stream.Collectors;
import org.apache.hadoop.conf.Configuration;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
@@ -60,32 +62,32 @@ public class PartitionBalanceAssignmentStrategyTest {
updateServerResource(list);
boolean isThrown = false;
try {
- strategy.assign(100, 2, 1, tags);
+ strategy.assign(100, 2, 1, tags, -1);
} catch (Exception e) {
isThrown = true;
}
assertTrue(isThrown);
try {
- strategy.assign(0, 1, 1, tags);
+ strategy.assign(0, 1, 1, tags, -1);
} catch (Exception e) {
fail();
}
isThrown = false;
try {
- strategy.assign(10, 1, 1, Sets.newHashSet("fake"));
+ strategy.assign(10, 1, 1, Sets.newHashSet("fake"), 1);
} catch (Exception e) {
isThrown = true;
}
assertTrue(isThrown);
- strategy.assign(100, 1, 1, tags);
+ strategy.assign(100, 1, 1, tags, -1);
List<Long> expect = Lists.newArrayList(20L, 20L, 20L, 20L, 20L, 0L, 0L,
0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L);
valid(expect);
- strategy.assign(75, 1, 1, tags);
+ strategy.assign(75, 1, 1, tags, -1);
expect = Lists.newArrayList(20L, 20L, 20L, 20L, 20L, 15L, 15L, 15L, 15L,
15L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L);
valid(expect);
- strategy.assign(100, 1, 1, tags);
+ strategy.assign(100, 1, 1, tags, -1);
expect = Lists.newArrayList(20L, 20L, 20L, 20L, 20L, 15L, 15L, 15L, 15L,
15L, 20L, 20L, 20L, 20L, 20L, 0L, 0L, 0L, 0L, 0L);
valid(expect);
@@ -94,16 +96,16 @@ public class PartitionBalanceAssignmentStrategyTest {
list = Lists.newArrayList(7L, 18L, 7L, 3L, 19L, 15L, 11L, 10L, 16L, 11L,
14L, 17L, 15L, 17L, 8L, 1L, 3L, 3L, 6L, 12L);
updateServerResource(list);
- strategy.assign(100, 1, 1, tags);
+ strategy.assign(100, 1, 1, tags, -1);
expect = Lists.newArrayList(0L, 20L, 0L, 0L, 20L, 0L, 0L, 0L, 20L, 0L,
0L, 20L, 0L, 20L, 0L, 0L, 0L, 0L, 0L, 0L);
valid(expect);
- strategy.assign(50, 1, 1, tags);
+ strategy.assign(50, 1, 1, tags, -1);
expect = Lists.newArrayList(0L, 20L, 0L, 0L, 20L, 10L, 10L, 0L, 20L, 0L,
10L, 20L, 10L, 20L, 0L, 0L, 0L, 0L, 0L, 10L);
valid(expect);
- strategy.assign(75, 1, 1, tags);
+ strategy.assign(75, 1, 1, tags, -1);
expect = Lists.newArrayList(0L, 20L, 0L, 0L, 20L, 25L, 10L, 15L, 20L, 15L,
25L, 20L, 25L, 20L, 0L, 0L, 0L, 0L, 0L, 10L);
valid(expect);
@@ -112,15 +114,15 @@ public class PartitionBalanceAssignmentStrategyTest {
list = Lists.newArrayList(7L, 18L, 7L, 3L, 19L, 15L, 11L, 10L, 16L, 11L,
14L, 17L, 15L, 17L, 8L, 1L, 3L, 3L, 6L, 12L);
updateServerResource(list);
- strategy.assign(50, 1, 2, tags);
+ strategy.assign(50, 1, 2, tags, -1);
expect = Lists.newArrayList(0L, 20L, 0L, 0L, 20L, 0L, 0L, 0L, 20L, 0L,
0L, 20L, 0L, 20L, 0L, 0L, 0L, 0L, 0L, 0L);
valid(expect);
- strategy.assign(75, 1, 2, tags);
+ strategy.assign(75, 1, 2, tags, -1);
expect = Lists.newArrayList(0L, 20L, 0L, 0L, 50L, 30L, 0L, 0L, 20L, 0L,
30L, 20L, 30L, 20L, 0L, 0L, 0L, 0L, 0L, 30L);
valid(expect);
- strategy.assign(33, 1, 2, tags);
+ strategy.assign(33, 1, 2, tags, -1);
expect = Lists.newArrayList(0L, 33L, 0L, 0L, 50L, 30L, 14L, 13L, 20L, 13L,
30L, 20L, 30L, 20L, 13L, 0L, 0L, 0L, 0L, 30L);
valid(expect);
@@ -136,19 +138,19 @@ public class PartitionBalanceAssignmentStrategyTest {
Uninterruptibles.sleepUninterruptibly(10, TimeUnit.MILLISECONDS);
updateServerResource(list);
- strategy.assign(33, 1, 1, tags);
+ strategy.assign(33, 1, 1, tags, -1);
expect = Lists.newArrayList(0L, 7L, 0L, 7L, 0L, 7L, 0L, 6L, 0L, 6L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L);
valid(expect);
- strategy.assign(41, 1, 2, tags);
+ strategy.assign(41, 1, 2, tags, -1);
expect = Lists.newArrayList(0L, 7L, 0L, 7L, 0L, 7L, 0L, 6L, 0L, 6L, 0L,
17L,
0L, 17L, 0L, 16L, 0L, 16L, 0L, 16L);
valid(expect);
- strategy.assign(23, 1, 1, tags);
+ strategy.assign(23, 1, 1, tags, -1);
expect = Lists.newArrayList(5L, 7L, 5L, 7L, 5L, 7L, 4L, 6L, 4L, 6L, 0L,
17L,
0L, 17L, 0L, 16L, 0L, 16L, 0L, 16L);
valid(expect);
- strategy.assign(11, 1, 3, tags);
+ strategy.assign(11, 1, 3, tags, -1);
expect = Lists.newArrayList(5L, 7L, 5L, 7L, 5L, 7L, 4L, 13L, 4L, 13L, 7L,
17L,
6L, 17L, 6L, 16L, 0L, 16L, 0L, 16L);
valid(expect);
@@ -191,4 +193,94 @@ public class PartitionBalanceAssignmentStrategyTest {
clusterManager.add(node);
}
}
+
+ @Test
+ public void testAssignmentShuffleNodesNum() {
+ Set<String> serverTags = Sets.newHashSet("tag-1");
+
+ for (int i = 0; i < 20; ++i) {
+ clusterManager.add(new ServerNode("t1-" + i, "", 0, 0, 0,
+ 20 - i, 0, serverTags, true));
+ }
+
+ /**
+ * case1: user specify the illegal shuffle node num(<0)
+ * it will use the default shuffle nodes num when having enough servers.
+ */
+ PartitionRangeAssignment pra = strategy.assign(100, 1, 1, serverTags, -1);
+ assertEquals(
+ shuffleNodesMax,
+ pra.getAssignments()
+ .values()
+ .stream()
+ .flatMap(Collection::stream)
+ .collect(Collectors.toSet())
+ .size()
+ );
+
+ /**
+ * case2: user specify the illegal shuffle node num(==0)
+ * it will use the default shuffle nodes num when having enough servers.
+ */
+ pra = strategy.assign(100, 1, 1, serverTags, 0);
+ assertEquals(
+ shuffleNodesMax,
+ pra.getAssignments()
+ .values()
+ .stream()
+ .flatMap(Collection::stream)
+ .collect(Collectors.toSet())
+ .size()
+ );
+
+ /**
+ * case3: user specify the illegal shuffle node num(>default max
limitation)
+ * it will use the default shuffle nodes num when having enough servers
+ */
+ pra = strategy.assign(100, 1, 1, serverTags, shuffleNodesMax + 10);
+ assertEquals(
+ shuffleNodesMax,
+ pra.getAssignments()
+ .values()
+ .stream()
+ .flatMap(Collection::stream)
+ .collect(Collectors.toSet())
+ .size()
+ );
+
+ /**
+ * case4: user specify the legal shuffle node num,
+ * it will use the customized shuffle nodes num when having enough servers
+ */
+ pra = strategy.assign(100, 1, 1, serverTags, shuffleNodesMax - 1);
+ assertEquals(
+ shuffleNodesMax - 1,
+ pra.getAssignments()
+ .values()
+ .stream()
+ .flatMap(Collection::stream)
+ .collect(Collectors.toSet())
+ .size()
+ );
+
+ /**
+ * case5: user specify the legal shuffle node num, but cluster dont have
enough servers,
+ * it will return the remaining servers.
+ */
+ serverTags = Sets.newHashSet("tag-2");
+ for (int i = 0; i < shuffleNodesMax - 1; ++i) {
+ clusterManager.add(new ServerNode("t2-" + i, "", 0, 0, 0,
+ 20 - i, 0, serverTags, true));
+ }
+ pra = strategy.assign(100, 1, 1, serverTags, shuffleNodesMax);
+ assertEquals(
+ shuffleNodesMax - 1,
+ pra.getAssignments()
+ .values()
+ .stream()
+ .flatMap(Collection::stream)
+ .collect(Collectors.toSet())
+ .size()
+ );
+ }
}
diff --git a/docs/client_guide.md b/docs/client_guide.md
index eee239f..b97474e 100644
--- a/docs/client_guide.md
+++ b/docs/client_guide.md
@@ -88,6 +88,7 @@ These configurations are shared by all types of clients.
|<client_type>.rss.client.send.threadPool.size|5|The thread size for send
shuffle data to shuffle server|
|<client_type>.rss.client.assignment.tags|-|The comma-separated list of tags
for deciding assignment shuffle servers. Notice that the SHUFFLE_SERVER_VERSION
will always as the assignment tag whether this conf is set or not|
|<client_type>.rss.client.data.commit.pool.size|The number of assigned shuffle
servers|The thread size for sending commit to shuffle servers|
+|<client_type>.rss.client.assignment.shuffle.nodes.max|-1|The number of
required assignment shuffle servers. If it is less than or equal to 0, or
greater than the coordinator's config "rss.coordinator.shuffle.nodes.max",
the value of "rss.coordinator.shuffle.nodes.max" is used instead|
Notice:
1. `<client_type>` should be `spark` or `mapreduce`
diff --git
a/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentServerNodesNumberTest.java
b/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentServerNodesNumberTest.java
new file mode 100644
index 0000000..57bf341
--- /dev/null
+++
b/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentServerNodesNumberTest.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.test;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import com.google.common.collect.Sets;
+import com.google.common.io.Files;
+import java.io.File;
+import java.util.ArrayList;
+import java.util.HashSet;
+import org.apache.uniffle.client.impl.ShuffleWriteClientImpl;
+import org.apache.uniffle.client.util.ClientType;
+import org.apache.uniffle.common.ShuffleAssignmentsInfo;
+import org.apache.uniffle.common.config.RssBaseConf;
+import org.apache.uniffle.coordinator.CoordinatorConf;
+import org.apache.uniffle.server.ShuffleServerConf;
+import org.apache.uniffle.storage.util.StorageType;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Integration test for the client-side assignment server number
+ * (rss.client.assignment.shuffle.nodes.max): values in the range
+ * (0, rss.coordinator.shuffle.nodes.max] are honored, anything else falls
+ * back to rss.coordinator.shuffle.nodes.max.
+ */
+public class AssignmentServerNodesNumberTest extends CoordinatorTestBase {
+ private static final Logger LOG =
LoggerFactory.getLogger(AssignmentServerNodesNumberTest.class);
+ // Coordinator-side upper bound (rss.coordinator.shuffle.nodes.max).
+ private static final int SHUFFLE_NODES_MAX = 10;
+ // Number of shuffle servers started; must be >= SHUFFLE_NODES_MAX so that
+ // every case below has enough servers to satisfy the requested number.
+ private static final int SERVER_NUM = 10;
+ private static final HashSet<String> TAGS = Sets.newHashSet("t1");
+
+ @BeforeAll
+ public static void setupServers() throws Exception {
+ CoordinatorConf coordinatorConf = getCoordinatorConf();
+ coordinatorConf.setLong(CoordinatorConf.COORDINATOR_APP_EXPIRED, 2000);
+ coordinatorConf.setInteger(CoordinatorConf.COORDINATOR_SHUFFLE_NODES_MAX,
SHUFFLE_NODES_MAX);
+ createCoordinatorServer(coordinatorConf);
+
+ // Start SERVER_NUM shuffle servers on distinct ports, all carrying TAGS.
+ for (int i = 0; i < SERVER_NUM; i++) {
+ ShuffleServerConf shuffleServerConf = getShuffleServerConf();
+ File tmpDir = Files.createTempDir();
+ File dataDir1 = new File(tmpDir, "data1");
+ String basePath = dataDir1.getAbsolutePath();
+ shuffleServerConf.set(ShuffleServerConf.RSS_STORAGE_TYPE,
StorageType.MEMORY_LOCALFILE_HDFS.name());
+ shuffleServerConf.set(ShuffleServerConf.RSS_STORAGE_BASE_PATH, basePath);
+ shuffleServerConf.set(RssBaseConf.RPC_METRICS_ENABLED, true);
+
shuffleServerConf.set(ShuffleServerConf.SERVER_APP_EXPIRED_WITHOUT_HEARTBEAT,
2000L);
+ shuffleServerConf.set(ShuffleServerConf.SERVER_PRE_ALLOCATION_EXPIRED,
5000L);
+ shuffleServerConf.setInteger(RssBaseConf.RPC_SERVER_PORT, 18001 + i);
+ shuffleServerConf.setInteger(RssBaseConf.JETTY_HTTP_PORT, 19010 + i);
+ shuffleServerConf.set(ShuffleServerConf.TAGS, new ArrayList<>(TAGS));
+ createShuffleServer(shuffleServerConf);
+ }
+ startServers();
+
+ // NOTE(review): presumably waits for the shuffle servers to register with
+ // the coordinator before assignments are requested — TODO confirm.
+ Thread.sleep(1000 * 5);
+ }
+
+ @Test
+ public void testAssignmentServerNodesNumber() throws Exception {
+ ShuffleWriteClientImpl shuffleWriteClient = new
ShuffleWriteClientImpl(ClientType.GRPC.name(), 3, 1000, 1,
+ 1, 1, 1, true, 1, 1);
+ shuffleWriteClient.registerCoordinators(COORDINATOR_QUORUM);
+
+ // case1: the user specifies an illegal shuffle node num (< 0), so the
+ // coordinator's max is used when there are enough servers.
+ ShuffleAssignmentsInfo info =
shuffleWriteClient.getShuffleAssignments("app1", 0, 10, 1, TAGS, -1);
+ assertEquals(SHUFFLE_NODES_MAX,
info.getServerToPartitionRanges().keySet().size());
+
+ // case2: the user specifies an illegal shuffle node num (== 0), so the
+ // coordinator's max is used when there are enough servers.
+ info = shuffleWriteClient.getShuffleAssignments("app1", 0, 10, 1, TAGS, 0);
+ assertEquals(SHUFFLE_NODES_MAX,
info.getServerToPartitionRanges().keySet().size());
+
+ // case3: the user specifies an illegal shuffle node num (> the coordinator's
+ // max limitation), so the coordinator's max is used when there are enough servers.
+ info = shuffleWriteClient.getShuffleAssignments("app1", 0, 10, 1, TAGS,
SHUFFLE_NODES_MAX + 10);
+ assertEquals(SHUFFLE_NODES_MAX,
info.getServerToPartitionRanges().keySet().size());
+
+ // case4: the user specifies a legal shuffle node num, so the customized
+ // number is used when there are enough servers.
+ info = shuffleWriteClient.getShuffleAssignments("app1", 0, 10, 1, TAGS,
SHUFFLE_NODES_MAX - 1);
+ assertEquals(SHUFFLE_NODES_MAX - 1,
info.getServerToPartitionRanges().keySet().size());
+ }
+}
diff --git
a/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentWithTagsTest.java
b/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentWithTagsTest.java
index 9ab84d4..4ca0c20 100644
---
a/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentWithTagsTest.java
+++
b/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentWithTagsTest.java
@@ -153,7 +153,7 @@ public class AssignmentWithTagsTest extends
CoordinatorTestBase {
// Case1 : only set the single default shuffle version tag
ShuffleAssignmentsInfo assignmentsInfo =
shuffleWriteClient.getShuffleAssignments("app-1",
- 1, 1, 1,
Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION));
+ 1, 1, 1,
Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION), 1);
List<Integer> assignedServerPorts = assignmentsInfo
.getPartitionToServers()
@@ -168,7 +168,7 @@ public class AssignmentWithTagsTest extends
CoordinatorTestBase {
// Case2: Set the single non-exist shuffle server tag
try {
assignmentsInfo = shuffleWriteClient.getShuffleAssignments("app-2",
- 1, 1, 1, Sets.newHashSet("non-exist"));
+ 1, 1, 1, Sets.newHashSet("non-exist"), 1);
fail();
} catch (Exception e) {
assertTrue(e.getMessage().startsWith("Error happened when
getShuffleAssignments with"));
@@ -176,7 +176,7 @@ public class AssignmentWithTagsTest extends
CoordinatorTestBase {
// Case3: Set the single fixed tag
assignmentsInfo = shuffleWriteClient.getShuffleAssignments("app-3",
- 1, 1, 1, Sets.newHashSet("fixed"));
+ 1, 1, 1, Sets.newHashSet("fixed"), 1);
assignedServerPorts = assignmentsInfo
.getPartitionToServers()
.values()
@@ -189,7 +189,7 @@ public class AssignmentWithTagsTest extends
CoordinatorTestBase {
// case4: Set the multiple tags if exists
assignmentsInfo = shuffleWriteClient.getShuffleAssignments("app-4",
- 1, 1, 1, Sets.newHashSet("fixed",
Constants.SHUFFLE_SERVER_VERSION));
+ 1, 1, 1, Sets.newHashSet("fixed",
Constants.SHUFFLE_SERVER_VERSION), 1);
assignedServerPorts = assignmentsInfo
.getPartitionToServers()
.values()
@@ -203,7 +203,7 @@ public class AssignmentWithTagsTest extends
CoordinatorTestBase {
// case5: Set the multiple tags if non-exist
try {
assignmentsInfo = shuffleWriteClient.getShuffleAssignments("app-5",
- 1, 1, 1, Sets.newHashSet("fixed", "elastic",
Constants.SHUFFLE_SERVER_VERSION));
+ 1, 1, 1, Sets.newHashSet("fixed", "elastic",
Constants.SHUFFLE_SERVER_VERSION), 1);
fail();
} catch (Exception e) {
assertTrue(e.getMessage().startsWith("Error happened when
getShuffleAssignments with"));
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/CoordinatorGrpcClient.java
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/CoordinatorGrpcClient.java
index dc1fa47..53e0922 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/CoordinatorGrpcClient.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/CoordinatorGrpcClient.java
@@ -153,7 +153,13 @@ public class CoordinatorGrpcClient extends GrpcClient
implements CoordinatorClie
}
public RssProtos.GetShuffleAssignmentsResponse doGetShuffleAssignments(
- String appId, int shuffleId, int numMaps, int partitionNumPerRange, int
dataReplica, Set<String> requiredTags) {
+ String appId,
+ int shuffleId,
+ int numMaps,
+ int partitionNumPerRange,
+ int dataReplica,
+ Set<String> requiredTags,
+ int assignmentShuffleServerNumber) {
RssProtos.GetShuffleServerRequest getServerRequest =
RssProtos.GetShuffleServerRequest.newBuilder()
.setApplicationId(appId)
@@ -162,6 +168,7 @@ public class CoordinatorGrpcClient extends GrpcClient
implements CoordinatorClie
.setPartitionNumPerRange(partitionNumPerRange)
.setDataReplica(dataReplica)
.addAllRequireTags(requiredTags)
+ .setAssignmentShuffleServerNumber(assignmentShuffleServerNumber)
.build();
return blockingStub.getShuffleAssignments(getServerRequest);
@@ -221,7 +228,8 @@ public class CoordinatorGrpcClient extends GrpcClient
implements CoordinatorClie
request.getPartitionNum(),
request.getPartitionNumPerRange(),
request.getDataReplica(),
- request.getRequiredTags());
+ request.getRequiredTags(),
+ request.getAssignmentShuffleServerNumber());
RssGetShuffleAssignmentsResponse response;
StatusCode statusCode = rpcResponse.getStatus();
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleAssignmentsRequest.java
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleAssignmentsRequest.java
index acf0e3d..d0971cb 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleAssignmentsRequest.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleAssignmentsRequest.java
@@ -19,6 +19,8 @@ package org.apache.uniffle.client.request;
import java.util.Set;
+import com.google.common.annotations.VisibleForTesting;
+
public class RssGetShuffleAssignmentsRequest {
private String appId;
@@ -27,15 +29,23 @@ public class RssGetShuffleAssignmentsRequest {
private int partitionNumPerRange;
private int dataReplica;
private Set<String> requiredTags;
+ private int assignmentShuffleServerNumber;
+ @VisibleForTesting
public RssGetShuffleAssignmentsRequest(String appId, int shuffleId, int
partitionNum,
int partitionNumPerRange, int dataReplica, Set<String> requiredTags) {
+ this(appId, shuffleId, partitionNum, partitionNumPerRange, dataReplica,
requiredTags, -1);
+ }
+
+ public RssGetShuffleAssignmentsRequest(String appId, int shuffleId, int
partitionNum,
+ int partitionNumPerRange, int dataReplica, Set<String> requiredTags, int
assignmentShuffleServerNumber) {
this.appId = appId;
this.shuffleId = shuffleId;
this.partitionNum = partitionNum;
this.partitionNumPerRange = partitionNumPerRange;
this.dataReplica = dataReplica;
this.requiredTags = requiredTags;
+ this.assignmentShuffleServerNumber = assignmentShuffleServerNumber;
}
public String getAppId() {
@@ -61,4 +71,8 @@ public class RssGetShuffleAssignmentsRequest {
public Set<String> getRequiredTags() {
return requiredTags;
}
+
+ public int getAssignmentShuffleServerNumber() {
+ return assignmentShuffleServerNumber;
+ }
}
diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto
index 647430d..d4d979c 100644
--- a/proto/src/main/proto/Rss.proto
+++ b/proto/src/main/proto/Rss.proto
@@ -295,6 +295,7 @@ message GetShuffleServerRequest {
int32 partitionNumPerRange = 7;
int32 dataReplica = 8;
repeated string requireTags = 9;
+ int32 assignmentShuffleServerNumber = 10;
}
message PartitionRangeAssignment {