This is an automated email from the ASF dual-hosted git repository.
roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 64c4a68ed [#1091] refactor: Refactor the writer code with builder
pattern (#1228)
64c4a68ed is described below
commit 64c4a68ed5bb86a24fe09e67e9fc14810444bdff
Author: summaryzb <[email protected]>
AuthorDate: Mon Oct 9 20:51:53 2023 -0500
[#1091] refactor: Refactor the writer code with builder pattern (#1228)
### What changes were proposed in this pull request?
Refactor construction of the shuffle write client to use the builder pattern, as stated in the title.
### Why are the changes needed?
https://github.com/apache/incubator-uniffle/issues/1091
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
unit test
---
.../org/apache/hadoop/mapreduce/RssMRUtils.java | 23 +-
.../spark/shuffle/writer/DataPusherTest.java | 42 ++--
.../apache/spark/shuffle/RssShuffleManager.java | 27 ++-
.../apache/spark/shuffle/RssShuffleManager.java | 54 +++--
.../java/org/apache/tez/common/RssTezUtils.java | 21 +-
.../apache/tez/dag/app/RssDAGAppMasterTest.java | 44 ++--
.../client/factory/ShuffleClientFactory.java | 266 ++++++++++++---------
.../client/impl/ShuffleWriteClientImpl.java | 81 ++-----
.../client/impl/ShuffleWriteClientImplTest.java | 62 ++++-
.../uniffle/test/AssignmentWithTagsTest.java | 16 +-
.../uniffle/test/CoordinatorAssignmentTest.java | 61 ++++-
.../java/org/apache/uniffle/test/QuorumTest.java | 54 ++---
.../apache/uniffle/test/ShuffleServerGrpcTest.java | 15 +-
.../uniffle/test/ShuffleWithRssClientTest.java | 16 +-
14 files changed, 442 insertions(+), 340 deletions(-)
diff --git
a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java
b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java
index 67ca72b54..5be31d305 100644
--- a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java
+++ b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java
@@ -117,17 +117,18 @@ public class RssMRUtils {
ShuffleWriteClient client =
ShuffleClientFactory.getInstance()
.createShuffleWriteClient(
- clientType,
- retryMax,
- retryIntervalMax,
- heartBeatThreadNum,
- replica,
- replicaWrite,
- replicaRead,
- replicaSkipEnabled,
- dataTransferPoolSize,
- dataCommitPoolSize,
- RssMRConfig.toRssConf(jobConf));
+ ShuffleClientFactory.newWriteBuilder()
+ .clientType(clientType)
+ .retryMax(retryMax)
+ .retryIntervalMax(retryIntervalMax)
+ .heartBeatThreadNum(heartBeatThreadNum)
+ .replica(replica)
+ .replicaWrite(replicaWrite)
+ .replicaRead(replicaRead)
+ .replicaSkipEnabled(replicaSkipEnabled)
+ .dataTransferPoolSize(dataTransferPoolSize)
+ .dataCommitPoolSize(dataCommitPoolSize)
+ .rssConf(RssMRConfig.toRssConf(jobConf)));
return client;
}
diff --git
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
index 8b1ce0082..979f92822 100644
---
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
+++
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
@@ -30,6 +30,7 @@ import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import org.junit.jupiter.api.Test;
+import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.client.impl.ShuffleWriteClientImpl;
import org.apache.uniffle.client.response.SendShuffleDataResult;
import org.apache.uniffle.common.ShuffleBlockInfo;
@@ -43,35 +44,20 @@ public class DataPusherTest {
private SendShuffleDataResult fakedShuffleDataResult;
FakedShuffleWriteClient() {
- this("GRPC", 1, 1, 10, 1, 1, 1, false, 1, 1, 1, 1);
- }
-
- private FakedShuffleWriteClient(
- String clientType,
- int retryMax,
- long retryIntervalMax,
- int heartBeatThreadNum,
- int replica,
- int replicaWrite,
- int replicaRead,
- boolean replicaSkipEnabled,
- int dataTransferPoolSize,
- int dataCommitPoolSize,
- int unregisterThreadPoolSize,
- int unregisterRequestTimeSec) {
super(
- clientType,
- retryMax,
- retryIntervalMax,
- heartBeatThreadNum,
- replica,
- replicaWrite,
- replicaRead,
- replicaSkipEnabled,
- dataTransferPoolSize,
- dataCommitPoolSize,
- unregisterThreadPoolSize,
- unregisterRequestTimeSec);
+ ShuffleClientFactory.newWriteBuilder()
+ .clientType("GRPC")
+ .retryMax(1)
+ .retryIntervalMax(1)
+ .heartBeatThreadNum(10)
+ .replica(1)
+ .replicaWrite(1)
+ .replicaRead(1)
+ .replicaSkipEnabled(true)
+ .dataTransferPoolSize(1)
+ .dataCommitPoolSize(1)
+ .unregisterThreadPoolSize(1)
+ .unregisterRequestTimeSec(1));
}
@Override
diff --git
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index e38cc570c..3943b5c9a 100644
---
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -150,19 +150,20 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
this.shuffleWriteClient =
ShuffleClientFactory.getInstance()
.createShuffleWriteClient(
- clientType,
- retryMax,
- retryIntervalMax,
- heartBeatThreadNum,
- dataReplica,
- dataReplicaWrite,
- dataReplicaRead,
- dataReplicaSkipEnabled,
- dataTransferPoolSize,
- dataCommitPoolSize,
- unregisterThreadPoolSize,
- unregisterRequestTimeoutSec,
- rssConf);
+ ShuffleClientFactory.newWriteBuilder()
+ .clientType(clientType)
+ .retryMax(retryMax)
+ .retryIntervalMax(retryIntervalMax)
+ .heartBeatThreadNum(heartBeatThreadNum)
+ .replica(dataReplica)
+ .replicaWrite(dataReplicaWrite)
+ .replicaRead(dataReplicaRead)
+ .replicaSkipEnabled(dataReplicaSkipEnabled)
+ .dataTransferPoolSize(dataTransferPoolSize)
+ .dataCommitPoolSize(dataCommitPoolSize)
+ .unregisterThreadPoolSize(unregisterThreadPoolSize)
+ .unregisterRequestTimeSec(unregisterRequestTimeoutSec)
+ .rssConf(rssConf));
registerCoordinator();
// fetch client conf and apply them if necessary and disable ESS
if (isDriver && dynamicConfEnabled) {
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 4605ab040..2cb7d7e8f 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -168,19 +168,20 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
shuffleWriteClient =
ShuffleClientFactory.getInstance()
.createShuffleWriteClient(
- clientType,
- retryMax,
- retryIntervalMax,
- heartBeatThreadNum,
- dataReplica,
- dataReplicaWrite,
- dataReplicaRead,
- dataReplicaSkipEnabled,
- dataTransferPoolSize,
- dataCommitPoolSize,
- unregisterThreadPoolSize,
- unregisterRequestTimeoutSec,
- rssConf);
+ ShuffleClientFactory.newWriteBuilder()
+ .clientType(clientType)
+ .retryMax(retryMax)
+ .retryIntervalMax(retryIntervalMax)
+ .heartBeatThreadNum(heartBeatThreadNum)
+ .replica(dataReplica)
+ .replicaWrite(dataReplicaWrite)
+ .replicaRead(dataReplicaRead)
+ .replicaSkipEnabled(dataReplicaSkipEnabled)
+ .dataTransferPoolSize(dataTransferPoolSize)
+ .dataCommitPoolSize(dataCommitPoolSize)
+ .unregisterThreadPoolSize(unregisterThreadPoolSize)
+ .unregisterRequestTimeSec(unregisterRequestTimeoutSec)
+ .rssConf(rssConf));
registerCoordinator();
// fetch client conf and apply them if necessary and disable ESS
if (isDriver && dynamicConfEnabled) {
@@ -302,19 +303,20 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
shuffleWriteClient =
ShuffleClientFactory.getInstance()
.createShuffleWriteClient(
- clientType,
- retryMax,
- retryIntervalMax,
- heartBeatThreadNum,
- dataReplica,
- dataReplicaWrite,
- dataReplicaRead,
- dataReplicaSkipEnabled,
- dataTransferPoolSize,
- dataCommitPoolSize,
- unregisterThreadPoolSize,
- unregisterRequestTimeoutSec,
- RssSparkConfig.toRssConf(sparkConf));
+ ShuffleClientFactory.newWriteBuilder()
+ .clientType(clientType)
+ .retryMax(retryMax)
+ .retryIntervalMax(retryIntervalMax)
+ .heartBeatThreadNum(heartBeatThreadNum)
+ .replica(dataReplica)
+ .replicaWrite(dataReplicaWrite)
+ .replicaRead(dataReplicaRead)
+ .replicaSkipEnabled(dataReplicaSkipEnabled)
+ .dataTransferPoolSize(dataTransferPoolSize)
+ .dataCommitPoolSize(dataCommitPoolSize)
+ .unregisterThreadPoolSize(unregisterThreadPoolSize)
+ .unregisterRequestTimeSec(unregisterRequestTimeoutSec)
+ .rssConf(RssSparkConfig.toRssConf(sparkConf)));
this.taskToSuccessBlockIds = taskToSuccessBlockIds;
this.taskToFailedBlockIds = taskToFailedBlockIds;
this.heartBeatScheduledExecutorService = null;
diff --git a/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java
b/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java
index 4b9f055a4..e497995f1 100644
--- a/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java
+++ b/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java
@@ -111,16 +111,17 @@ public class RssTezUtils {
ShuffleWriteClient client =
ShuffleClientFactory.getInstance()
.createShuffleWriteClient(
- clientType,
- retryMax,
- retryIntervalMax,
- heartBeatThreadNum,
- replica,
- replicaWrite,
- replicaRead,
- replicaSkipEnabled,
- dataTransferPoolSize,
- dataCommitPoolSize);
+ ShuffleClientFactory.newWriteBuilder()
+ .clientType(clientType)
+ .retryMax(retryMax)
+ .retryIntervalMax(retryIntervalMax)
+ .heartBeatThreadNum(heartBeatThreadNum)
+ .replica(replica)
+ .replicaWrite(replicaWrite)
+ .replicaRead(replicaRead)
+ .replicaSkipEnabled(replicaSkipEnabled)
+ .dataTransferPoolSize(dataTransferPoolSize)
+ .dataCommitPoolSize(dataCommitPoolSize));
return client;
}
diff --git
a/client-tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterTest.java
b/client-tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterTest.java
index 910425588..ba1921b60 100644
--- a/client-tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterTest.java
+++ b/client-tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterTest.java
@@ -87,6 +87,7 @@ import
org.apache.tez.runtime.library.processor.SimpleProcessor;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
+import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.client.impl.ShuffleWriteClientImpl;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.exception.RssException;
@@ -497,36 +498,21 @@ public class RssDAGAppMasterTest {
private int mode;
FakedShuffleWriteClient(int mode) {
- this("GRPC", 1, 1, 10, 1, 1, 1, false, 1, 1, 1, 1);
- this.mode = mode;
- }
-
- private FakedShuffleWriteClient(
- String clientType,
- int retryMax,
- long retryIntervalMax,
- int heartBeatThreadNum,
- int replica,
- int replicaWrite,
- int replicaRead,
- boolean replicaSkipEnabled,
- int dataTransferPoolSize,
- int dataCommitPoolSize,
- int unregisterThreadPoolSize,
- int unregisterRequestTimeSec) {
super(
- clientType,
- retryMax,
- retryIntervalMax,
- heartBeatThreadNum,
- replica,
- replicaWrite,
- replicaRead,
- replicaSkipEnabled,
- dataTransferPoolSize,
- dataCommitPoolSize,
- unregisterThreadPoolSize,
- unregisterRequestTimeSec);
+ ShuffleClientFactory.newWriteBuilder()
+ .clientType("GRPC")
+ .retryMax(1)
+ .retryIntervalMax(1)
+ .heartBeatThreadNum(10)
+ .replica(1)
+ .replicaWrite(1)
+ .replicaRead(1)
+ .replicaSkipEnabled(true)
+ .dataTransferPoolSize(1)
+ .dataCommitPoolSize(1)
+ .unregisterThreadPoolSize(1)
+ .unregisterRequestTimeSec(1));
+ this.mode = mode;
}
@Override
diff --git
a/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
b/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
index fcf4285ae..cdabf0533 100644
---
a/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
+++
b/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
@@ -34,124 +34,11 @@ public class ShuffleClientFactory {
return INSTANCE;
}
- /** Only for MR engine, which won't used to unregister to remote
shuffle-servers */
- public ShuffleWriteClient createShuffleWriteClient(
- String clientType,
- int retryMax,
- long retryIntervalMax,
- int heartBeatThreadNum,
- int replica,
- int replicaWrite,
- int replicaRead,
- boolean replicaSkipEnabled,
- int dataTransferPoolSize,
- int dataCommitPoolSize) {
- return createShuffleWriteClient(
- clientType,
- retryMax,
- retryIntervalMax,
- heartBeatThreadNum,
- replica,
- replicaWrite,
- replicaRead,
- replicaSkipEnabled,
- dataTransferPoolSize,
- dataCommitPoolSize,
- 10,
- 10,
- new RssConf());
- }
-
- public ShuffleWriteClient createShuffleWriteClient(
- String clientType,
- int retryMax,
- long retryIntervalMax,
- int heartBeatThreadNum,
- int replica,
- int replicaWrite,
- int replicaRead,
- boolean replicaSkipEnabled,
- int dataTransferPoolSize,
- int dataCommitPoolSize,
- RssConf rssConf) {
- return createShuffleWriteClient(
- clientType,
- retryMax,
- retryIntervalMax,
- heartBeatThreadNum,
- replica,
- replicaWrite,
- replicaRead,
- replicaSkipEnabled,
- dataTransferPoolSize,
- dataCommitPoolSize,
- 10,
- 10,
- rssConf);
- }
-
- public ShuffleWriteClient createShuffleWriteClient(
- String clientType,
- int retryMax,
- long retryIntervalMax,
- int heartBeatThreadNum,
- int replica,
- int replicaWrite,
- int replicaRead,
- boolean replicaSkipEnabled,
- int dataTransferPoolSize,
- int dataCommitPoolSize,
- int unregisterThreadPoolSize,
- int unregisterRequestTimeoutSec) {
- return createShuffleWriteClient(
- clientType,
- retryMax,
- retryIntervalMax,
- heartBeatThreadNum,
- replica,
- replicaWrite,
- replicaRead,
- replicaSkipEnabled,
- dataTransferPoolSize,
- dataCommitPoolSize,
- unregisterThreadPoolSize,
- unregisterRequestTimeoutSec,
- new RssConf());
- }
-
- public ShuffleWriteClient createShuffleWriteClient(
- String clientType,
- int retryMax,
- long retryIntervalMax,
- int heartBeatThreadNum,
- int replica,
- int replicaWrite,
- int replicaRead,
- boolean replicaSkipEnabled,
- int dataTransferPoolSize,
- int dataCommitPoolSize,
- int unregisterThreadPoolSize,
- int unregisterRequestTimeoutSec,
- RssConf rssConf) {
- // If replica > replicaWrite, blocks maybe be sent for 2 rounds.
- // We need retry less times in this case for let the first round fail fast.
- if (replicaSkipEnabled && replica > replicaWrite) {
- retryMax = retryMax / 2;
- }
- return new ShuffleWriteClientImpl(
- clientType,
- retryMax,
- retryIntervalMax,
- heartBeatThreadNum,
- replica,
- replicaWrite,
- replicaRead,
- replicaSkipEnabled,
- dataTransferPoolSize,
- dataCommitPoolSize,
- unregisterThreadPoolSize,
- unregisterRequestTimeoutSec,
- rssConf);
+ public ShuffleWriteClient createShuffleWriteClient(WriteClientBuilder
builder) {
+ if (builder.isReplicaSkipEnabled() && builder.getReplica() >
builder.getReplicaWrite()) {
+ builder.retryMax(builder.getRetryMax() / 2);
+ }
+ return builder.build();
}
public ShuffleReadClient
createShuffleReadClient(CreateShuffleReadClientRequest request) {
@@ -171,4 +58,147 @@ public class ShuffleClientFactory {
request.isExpectedTaskIdsBitmapFilterEnable(),
request.getRssConf());
}
+
+ public static class WriteClientBuilder {
+ private WriteClientBuilder() {}
+
+ private String clientType;
+ private int retryMax;
+ private long retryIntervalMax;
+ private int heartBeatThreadNum;
+ private int replica;
+ private int replicaWrite;
+ private int replicaRead;
+ private boolean replicaSkipEnabled;
+ private int dataTransferPoolSize;
+ private int dataCommitPoolSize;
+ private int unregisterThreadPoolSize;
+ private int unregisterRequestTimeSec;
+ private RssConf rssConf;
+
+ public String getClientType() {
+ return clientType;
+ }
+
+ public int getRetryMax() {
+ return retryMax;
+ }
+
+ public long getRetryIntervalMax() {
+ return retryIntervalMax;
+ }
+
+ public int getHeartBeatThreadNum() {
+ return heartBeatThreadNum;
+ }
+
+ public int getReplica() {
+ return replica;
+ }
+
+ public int getReplicaWrite() {
+ return replicaWrite;
+ }
+
+ public int getReplicaRead() {
+ return replicaRead;
+ }
+
+ public boolean isReplicaSkipEnabled() {
+ return replicaSkipEnabled;
+ }
+
+ public int getDataTransferPoolSize() {
+ return dataTransferPoolSize;
+ }
+
+ public int getDataCommitPoolSize() {
+ return dataCommitPoolSize;
+ }
+
+ public int getUnregisterThreadPoolSize() {
+ return unregisterThreadPoolSize;
+ }
+
+ public int getUnregisterRequestTimeSec() {
+ return unregisterRequestTimeSec;
+ }
+
+ public RssConf getRssConf() {
+ return rssConf;
+ }
+
+ public WriteClientBuilder clientType(String clientType) {
+ this.clientType = clientType;
+ return this;
+ }
+
+ public WriteClientBuilder retryMax(int retryMax) {
+ this.retryMax = retryMax;
+ return this;
+ }
+
+ public WriteClientBuilder retryIntervalMax(long retryIntervalMax) {
+ this.retryIntervalMax = retryIntervalMax;
+ return this;
+ }
+
+ public WriteClientBuilder heartBeatThreadNum(int heartBeatThreadNum) {
+ this.heartBeatThreadNum = heartBeatThreadNum;
+ return this;
+ }
+
+ public WriteClientBuilder replica(int replica) {
+ this.replica = replica;
+ return this;
+ }
+
+ public WriteClientBuilder replicaWrite(int replicaWrite) {
+ this.replicaWrite = replicaWrite;
+ return this;
+ }
+
+ public WriteClientBuilder replicaRead(int replicaRead) {
+ this.replicaRead = replicaRead;
+ return this;
+ }
+
+ public WriteClientBuilder replicaSkipEnabled(boolean replicaSkipEnabled) {
+ this.replicaSkipEnabled = replicaSkipEnabled;
+ return this;
+ }
+
+ public WriteClientBuilder dataTransferPoolSize(int dataTransferPoolSize) {
+ this.dataTransferPoolSize = dataTransferPoolSize;
+ return this;
+ }
+
+ public WriteClientBuilder dataCommitPoolSize(int dataCommitPoolSize) {
+ this.dataCommitPoolSize = dataCommitPoolSize;
+ return this;
+ }
+
+ public WriteClientBuilder unregisterThreadPoolSize(int
unregisterThreadPoolSize) {
+ this.unregisterThreadPoolSize = unregisterThreadPoolSize;
+ return this;
+ }
+
+ public WriteClientBuilder unregisterRequestTimeSec(int
unregisterRequestTimeSec) {
+ this.unregisterRequestTimeSec = unregisterRequestTimeSec;
+ return this;
+ }
+
+ public WriteClientBuilder rssConf(RssConf rssConf) {
+ this.rssConf = rssConf;
+ return this;
+ }
+
+ public ShuffleWriteClientImpl build() {
+ return new ShuffleWriteClientImpl(this);
+ }
+ }
+
+ public static WriteClientBuilder newWriteBuilder() {
+ return new WriteClientBuilder();
+ }
}
diff --git
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
index 0c0b278c8..5c9932a50 100644
---
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
+++
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
@@ -48,6 +48,7 @@ import org.apache.uniffle.client.api.CoordinatorClient;
import org.apache.uniffle.client.api.ShuffleServerClient;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.factory.CoordinatorClientFactory;
+import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.client.factory.ShuffleServerClientFactory;
import org.apache.uniffle.client.request.RssAppHeartBeatRequest;
import org.apache.uniffle.client.request.RssApplicationInfoRequest;
@@ -115,67 +116,35 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
private Set<ShuffleServerInfo> defectiveServers;
private RssConf rssConf;
- public ShuffleWriteClientImpl(
- String clientType,
- int retryMax,
- long retryIntervalMax,
- int heartBeatThreadNum,
- int replica,
- int replicaWrite,
- int replicaRead,
- boolean replicaSkipEnabled,
- int dataTransferPoolSize,
- int dataCommitPoolSize,
- int unregisterThreadPoolSize,
- int unregisterRequestTimeSec) {
- this(
- clientType,
- retryMax,
- retryIntervalMax,
- heartBeatThreadNum,
- replica,
- replicaWrite,
- replicaRead,
- replicaSkipEnabled,
- dataTransferPoolSize,
- dataCommitPoolSize,
- unregisterThreadPoolSize,
- unregisterRequestTimeSec,
- new RssConf());
- }
-
- public ShuffleWriteClientImpl(
- String clientType,
- int retryMax,
- long retryIntervalMax,
- int heartBeatThreadNum,
- int replica,
- int replicaWrite,
- int replicaRead,
- boolean replicaSkipEnabled,
- int dataTransferPoolSize,
- int dataCommitPoolSize,
- int unregisterThreadPoolSize,
- int unregisterRequestTimeSec,
- RssConf rssConf) {
- this.clientType = clientType;
- this.retryMax = retryMax;
- this.retryIntervalMax = retryIntervalMax;
+ public ShuffleWriteClientImpl(ShuffleClientFactory.WriteClientBuilder
builder) {
+ // apply defaults: fall back to an empty RssConf only when the caller supplied none
+ if (builder.getRssConf() == null) {
+ builder.rssConf(new RssConf());
+ }
+ if (builder.getUnregisterThreadPoolSize() == 0) {
+ builder.unregisterThreadPoolSize(10);
+ }
+ if (builder.getUnregisterRequestTimeSec() == 0) {
+ builder.unregisterRequestTimeSec(10);
+ }
+ this.clientType = builder.getClientType();
+ this.retryMax = builder.getRetryMax();
+ this.retryIntervalMax = builder.getRetryIntervalMax();
this.coordinatorClientFactory = new
CoordinatorClientFactory(ClientType.valueOf(clientType));
this.heartBeatExecutorService =
- ThreadUtils.getDaemonFixedThreadPool(heartBeatThreadNum,
"client-heartbeat");
- this.replica = replica;
- this.replicaWrite = replicaWrite;
- this.replicaRead = replicaRead;
- this.replicaSkipEnabled = replicaSkipEnabled;
- this.dataTransferPool = Executors.newFixedThreadPool(dataTransferPoolSize);
- this.dataCommitPoolSize = dataCommitPoolSize;
- this.unregisterThreadPoolSize = unregisterThreadPoolSize;
- this.unregisterRequestTimeSec = unregisterRequestTimeSec;
+ ThreadUtils.getDaemonFixedThreadPool(builder.getHeartBeatThreadNum(),
"client-heartbeat");
+ this.replica = builder.getReplica();
+ this.replicaWrite = builder.getReplicaWrite();
+ this.replicaRead = builder.getReplicaRead();
+ this.replicaSkipEnabled = builder.isReplicaSkipEnabled();
+ this.dataTransferPool =
Executors.newFixedThreadPool(builder.getDataTransferPoolSize());
+ this.dataCommitPoolSize = builder.getDataCommitPoolSize();
+ this.unregisterThreadPoolSize = builder.getUnregisterThreadPoolSize();
+ this.unregisterRequestTimeSec = builder.getUnregisterRequestTimeSec();
if (replica > 1) {
defectiveServers = Sets.newConcurrentHashSet();
}
- this.rssConf = rssConf;
+ this.rssConf = builder.getRssConf();
}
private boolean sendShuffleDataAsync(
diff --git
a/client/src/test/java/org/apache/uniffle/client/impl/ShuffleWriteClientImplTest.java
b/client/src/test/java/org/apache/uniffle/client/impl/ShuffleWriteClientImplTest.java
index ac4ed009c..d9e4f7ff3 100644
---
a/client/src/test/java/org/apache/uniffle/client/impl/ShuffleWriteClientImplTest.java
+++
b/client/src/test/java/org/apache/uniffle/client/impl/ShuffleWriteClientImplTest.java
@@ -29,8 +29,10 @@ import org.mockito.Mockito;
import org.mockito.stubbing.Answer;
import org.apache.uniffle.client.api.ShuffleServerClient;
+import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.client.response.RssSendShuffleDataResponse;
import org.apache.uniffle.client.response.SendShuffleDataResult;
+import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.rpc.StatusCode;
@@ -47,7 +49,20 @@ public class ShuffleWriteClientImplTest {
@Test
public void testAbandonEventWhenTaskFailed() {
ShuffleWriteClientImpl shuffleWriteClient =
- new ShuffleWriteClientImpl("GRPC", 3, 2000, 4, 1, 1, 1, true, 1, 1,
10, 10);
+ ShuffleClientFactory.newWriteBuilder()
+ .clientType(ClientType.GRPC.name())
+ .retryMax(3)
+ .retryIntervalMax(2000)
+ .heartBeatThreadNum(4)
+ .replica(1)
+ .replicaWrite(1)
+ .replicaRead(1)
+ .replicaSkipEnabled(true)
+ .dataTransferPoolSize(1)
+ .dataCommitPoolSize(1)
+ .unregisterThreadPoolSize(10)
+ .unregisterRequestTimeSec(10)
+ .build();
ShuffleServerClient mockShuffleServerClient =
mock(ShuffleServerClient.class);
ShuffleWriteClientImpl spyClient = Mockito.spy(shuffleWriteClient);
doReturn(mockShuffleServerClient).when(spyClient).getShuffleServerClient(any());
@@ -80,7 +95,20 @@ public class ShuffleWriteClientImplTest {
@Test
public void testSendData() {
ShuffleWriteClientImpl shuffleWriteClient =
- new ShuffleWriteClientImpl("GRPC", 3, 2000, 4, 1, 1, 1, true, 1, 1,
10, 10);
+ ShuffleClientFactory.newWriteBuilder()
+ .clientType(ClientType.GRPC.name())
+ .retryMax(3)
+ .retryIntervalMax(2000)
+ .heartBeatThreadNum(4)
+ .replica(1)
+ .replicaWrite(1)
+ .replicaRead(1)
+ .replicaSkipEnabled(true)
+ .dataTransferPoolSize(1)
+ .dataCommitPoolSize(1)
+ .unregisterThreadPoolSize(10)
+ .unregisterRequestTimeSec(10)
+ .build();
ShuffleServerClient mockShuffleServerClient =
mock(ShuffleServerClient.class);
ShuffleWriteClientImpl spyClient = Mockito.spy(shuffleWriteClient);
doReturn(mockShuffleServerClient).when(spyClient).getShuffleServerClient(any());
@@ -102,7 +130,20 @@ public class ShuffleWriteClientImplTest {
@Test
public void testRegisterAndUnRegisterShuffleServer() {
ShuffleWriteClientImpl shuffleWriteClient =
- new ShuffleWriteClientImpl("GRPC", 3, 2000, 4, 1, 1, 1, true, 1, 1,
10, 10);
+ ShuffleClientFactory.newWriteBuilder()
+ .clientType(ClientType.GRPC.name())
+ .retryMax(3)
+ .retryIntervalMax(2000)
+ .heartBeatThreadNum(4)
+ .replica(1)
+ .replicaWrite(1)
+ .replicaRead(1)
+ .replicaSkipEnabled(true)
+ .dataTransferPoolSize(1)
+ .dataCommitPoolSize(1)
+ .unregisterThreadPoolSize(10)
+ .unregisterRequestTimeSec(10)
+ .build();
String appId1 = "testRegisterAndUnRegisterShuffleServer-1";
String appId2 = "testRegisterAndUnRegisterShuffleServer-2";
ShuffleServerInfo server1 = new ShuffleServerInfo("host1-0", "host1", 0);
@@ -127,7 +168,20 @@ public class ShuffleWriteClientImplTest {
@Test
public void testSendDataWithDefectiveServers() {
ShuffleWriteClientImpl shuffleWriteClient =
- new ShuffleWriteClientImpl("GRPC", 3, 2000, 4, 3, 2, 2, true, 1, 1,
10, 10);
+ ShuffleClientFactory.newWriteBuilder()
+ .clientType(ClientType.GRPC.name())
+ .retryMax(3)
+ .retryIntervalMax(2000)
+ .heartBeatThreadNum(4)
+ .replica(3)
+ .replicaWrite(2)
+ .replicaRead(2)
+ .replicaSkipEnabled(true)
+ .dataTransferPoolSize(1)
+ .dataCommitPoolSize(1)
+ .unregisterThreadPoolSize(10)
+ .unregisterRequestTimeSec(10)
+ .build();
ShuffleServerClient mockShuffleServerClient =
mock(ShuffleServerClient.class);
ShuffleWriteClientImpl spyClient = Mockito.spy(shuffleWriteClient);
doReturn(mockShuffleServerClient).when(spyClient).getShuffleServerClient(any());
diff --git
a/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentWithTagsTest.java
b/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentWithTagsTest.java
index 52982adf3..5f76e5a9d 100644
---
a/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentWithTagsTest.java
+++
b/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentWithTagsTest.java
@@ -35,6 +35,7 @@ import org.junit.jupiter.api.io.TempDir;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.client.impl.ShuffleWriteClientImpl;
import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.ShuffleAssignmentsInfo;
@@ -151,7 +152,20 @@ public class AssignmentWithTagsTest extends
CoordinatorTestBase {
@Test
public void testTags() throws Exception {
ShuffleWriteClientImpl shuffleWriteClient =
- new ShuffleWriteClientImpl(ClientType.GRPC.name(), 3, 1000, 1, 1, 1,
1, true, 1, 1, 10, 10);
+ ShuffleClientFactory.newWriteBuilder()
+ .clientType(ClientType.GRPC.name())
+ .retryMax(3)
+ .retryIntervalMax(1000)
+ .heartBeatThreadNum(1)
+ .replica(1)
+ .replicaWrite(1)
+ .replicaRead(1)
+ .replicaSkipEnabled(true)
+ .dataTransferPoolSize(1)
+ .dataCommitPoolSize(1)
+ .unregisterThreadPoolSize(10)
+ .unregisterRequestTimeSec(10)
+ .build();
shuffleWriteClient.registerCoordinators(COORDINATOR_QUORUM);
// Case1 : only set the single default shuffle version tag
diff --git
a/integration-test/common/src/test/java/org/apache/uniffle/test/CoordinatorAssignmentTest.java
b/integration-test/common/src/test/java/org/apache/uniffle/test/CoordinatorAssignmentTest.java
index 4b892ede8..8161db303 100644
---
a/integration-test/common/src/test/java/org/apache/uniffle/test/CoordinatorAssignmentTest.java
+++
b/integration-test/common/src/test/java/org/apache/uniffle/test/CoordinatorAssignmentTest.java
@@ -36,6 +36,7 @@ import org.junit.jupiter.api.io.TempDir;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.client.impl.ShuffleWriteClientImpl;
import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.ShuffleAssignmentsInfo;
@@ -105,7 +106,20 @@ public class CoordinatorAssignmentTest extends
CoordinatorTestBase {
@Test
public void testSilentPeriod() throws Exception {
ShuffleWriteClientImpl shuffleWriteClient =
- new ShuffleWriteClientImpl(ClientType.GRPC.name(), 3, 1000, 1, 1, 1,
1, true, 1, 1, 10, 10);
+ ShuffleClientFactory.newWriteBuilder()
+ .clientType(ClientType.GRPC.name())
+ .retryMax(3)
+ .retryIntervalMax(1000)
+ .heartBeatThreadNum(1)
+ .replica(1)
+ .replicaWrite(1)
+ .replicaRead(1)
+ .replicaSkipEnabled(true)
+ .dataTransferPoolSize(1)
+ .dataCommitPoolSize(1)
+ .unregisterThreadPoolSize(10)
+ .unregisterRequestTimeSec(10)
+ .build();
shuffleWriteClient.registerCoordinators(QUORUM);
// Case1: Disable silent period
@@ -132,7 +146,20 @@ public class CoordinatorAssignmentTest extends
CoordinatorTestBase {
@Test
public void testAssignmentServerNodesNumber() throws Exception {
ShuffleWriteClientImpl shuffleWriteClient =
- new ShuffleWriteClientImpl(ClientType.GRPC.name(), 3, 1000, 1, 1, 1,
1, true, 1, 1, 10, 10);
+ ShuffleClientFactory.newWriteBuilder()
+ .clientType(ClientType.GRPC.name())
+ .retryMax(3)
+ .retryIntervalMax(1000)
+ .heartBeatThreadNum(1)
+ .replica(1)
+ .replicaWrite(1)
+ .replicaRead(1)
+ .replicaSkipEnabled(true)
+ .dataTransferPoolSize(1)
+ .dataCommitPoolSize(1)
+ .unregisterThreadPoolSize(10)
+ .unregisterRequestTimeSec(10)
+ .build();
shuffleWriteClient.registerCoordinators(COORDINATOR_QUORUM);
/**
@@ -174,7 +201,20 @@ public class CoordinatorAssignmentTest extends
CoordinatorTestBase {
.getString(ReconfigurableBase.RECONFIGURABLE_FILE_NAME, "");
new File(fileName).createNewFile();
ShuffleWriteClientImpl shuffleWriteClient =
- new ShuffleWriteClientImpl(ClientType.GRPC.name(), 3, 1000, 1, 1, 1,
1, true, 1, 1, 10, 10);
+ ShuffleClientFactory.newWriteBuilder()
+ .clientType(ClientType.GRPC.name())
+ .retryMax(3)
+ .retryIntervalMax(1000)
+ .heartBeatThreadNum(1)
+ .replica(1)
+ .replicaWrite(1)
+ .replicaRead(1)
+ .replicaSkipEnabled(true)
+ .dataTransferPoolSize(1)
+ .dataCommitPoolSize(1)
+ .unregisterThreadPoolSize(10)
+ .unregisterRequestTimeSec(10)
+ .build();
shuffleWriteClient.registerCoordinators(COORDINATOR_QUORUM);
ShuffleAssignmentsInfo info =
shuffleWriteClient.getShuffleAssignments("app1", 0, 10, 1, TAGS,
SERVER_NUM + 10, -1);
@@ -195,7 +235,20 @@ public class CoordinatorAssignmentTest extends
CoordinatorTestBase {
@Test
public void testGetReShuffleAssignments() {
ShuffleWriteClientImpl shuffleWriteClient =
- new ShuffleWriteClientImpl(ClientType.GRPC.name(), 3, 1000, 1, 1, 1,
1, true, 1, 1, 10, 10);
+ ShuffleClientFactory.newWriteBuilder()
+ .clientType(ClientType.GRPC.name())
+ .retryMax(3)
+ .retryIntervalMax(1000)
+ .heartBeatThreadNum(1)
+ .replica(1)
+ .replicaWrite(1)
+ .replicaRead(1)
+ .replicaSkipEnabled(true)
+ .dataTransferPoolSize(1)
+ .dataCommitPoolSize(1)
+ .unregisterThreadPoolSize(10)
+ .unregisterRequestTimeSec(10)
+ .build();
shuffleWriteClient.registerCoordinators(COORDINATOR_QUORUM);
Set<String> excludeServer = Sets.newConcurrentHashSet();
List<ShuffleServer> excludeShuffleServer =
diff --git
a/integration-test/common/src/test/java/org/apache/uniffle/test/QuorumTest.java
b/integration-test/common/src/test/java/org/apache/uniffle/test/QuorumTest.java
index 324ef8da3..f653103fd 100644
---
a/integration-test/common/src/test/java/org/apache/uniffle/test/QuorumTest.java
+++
b/integration-test/common/src/test/java/org/apache/uniffle/test/QuorumTest.java
@@ -30,6 +30,7 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
+import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.client.factory.ShuffleServerClientFactory;
import org.apache.uniffle.client.impl.ShuffleReadClientImpl;
import org.apache.uniffle.client.impl.ShuffleWriteClientImpl;
@@ -300,32 +301,8 @@ public class QuorumTest extends ShuffleReadWriteBase {
}
static class MockedShuffleWriteClientImpl extends ShuffleWriteClientImpl {
- MockedShuffleWriteClientImpl(
- String clientType,
- int retryMax,
- long retryIntervalMax,
- int heartBeatThreadNum,
- int replica,
- int replicaWrite,
- int replicaRead,
- boolean replicaSkipEnabled,
- int dataTranferPoolSize,
- int dataCommitPoolSize,
- int unregisterThreadPoolSize,
- int unregisterRequestTimeSec) {
- super(
- clientType,
- retryMax,
- retryIntervalMax,
- heartBeatThreadNum,
- replica,
- replicaWrite,
- replicaRead,
- replicaSkipEnabled,
- dataTranferPoolSize,
- dataCommitPoolSize,
- unregisterThreadPoolSize,
- unregisterRequestTimeSec);
+ MockedShuffleWriteClientImpl(ShuffleClientFactory.WriteClientBuilder builder) {
+ super(builder);
}
public SendShuffleDataResult sendShuffleData(
@@ -339,18 +316,19 @@ public class QuorumTest extends ShuffleReadWriteBase {
shuffleWriteClientImpl =
new MockedShuffleWriteClientImpl(
- ClientType.GRPC.name(),
- 3,
- 1000,
- 1,
- replica,
- replicaWrite,
- replicaRead,
- replicaSkip,
- 1,
- 1,
- 10,
- 10);
+ ShuffleClientFactory.newWriteBuilder()
+ .clientType(ClientType.GRPC.name())
+ .retryMax(3)
+ .retryIntervalMax(1000)
+ .heartBeatThreadNum(1)
+ .replica(replica)
+ .replicaWrite(replicaWrite)
+ .replicaRead(replicaRead)
+ .replicaSkipEnabled(replicaSkip)
+ .dataTransferPoolSize(1)
+ .dataCommitPoolSize(1)
+ .unregisterThreadPoolSize(10)
+ .unregisterRequestTimeSec(10));
List<ShuffleServerInfo> allServers =
Lists.newArrayList(
diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java
index ecc8ab999..7c5acdb17 100644
--- a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java
+++ b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java
@@ -110,7 +110,20 @@ public class ShuffleServerGrpcTest extends IntegrationTestBase {
public void clearResourceTest() throws Exception {
final ShuffleWriteClient shuffleWriteClient =
ShuffleClientFactory.getInstance()
- .createShuffleWriteClient("GRPC", 2, 10000L, 4, 1, 1, 1, true, 1, 1, 10, 10);
+ .createShuffleWriteClient(
+ ShuffleClientFactory.newWriteBuilder()
+ .clientType("GRPC")
+ .retryMax(2)
+ .retryIntervalMax(10000L)
+ .heartBeatThreadNum(4)
+ .replica(1)
+ .replicaWrite(1)
+ .replicaRead(1)
+ .replicaSkipEnabled(true)
+ .dataTransferPoolSize(1)
+ .dataCommitPoolSize(1)
+ .unregisterThreadPoolSize(10)
+ .unregisterRequestTimeSec(10));
shuffleWriteClient.registerCoordinators("127.0.0.1:19999");
shuffleWriteClient.registerShuffle(
new ShuffleServerInfo("127.0.0.1-20001", "127.0.0.1", 20001),
diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java
index 12cb84c0d..2ad1e0492 100644
--- a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java
+++ b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java
@@ -32,6 +32,7 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
+import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.client.impl.ShuffleReadClientImpl;
import org.apache.uniffle.client.impl.ShuffleWriteClientImpl;
import org.apache.uniffle.client.response.SendShuffleDataResult;
@@ -94,7 +95,20 @@ public class ShuffleWithRssClientTest extends ShuffleReadWriteBase {
@BeforeEach
public void createClient() {
shuffleWriteClientImpl =
- new ShuffleWriteClientImpl(ClientType.GRPC.name(), 3, 1000, 1, 1, 1, 1, true, 1, 1, 10, 10);
+ ShuffleClientFactory.newWriteBuilder()
+ .clientType(ClientType.GRPC.name())
+ .retryMax(3)
+ .retryIntervalMax(1000)
+ .heartBeatThreadNum(1)
+ .replica(1)
+ .replicaWrite(1)
+ .replicaRead(1)
+ .replicaSkipEnabled(true)
+ .dataTransferPoolSize(1)
+ .dataCommitPoolSize(1)
+ .unregisterThreadPoolSize(10)
+ .unregisterRequestTimeSec(10)
+ .build();
}
@AfterEach