This is an automated email from the ASF dual-hosted git repository.
xianjin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 457c86536 [#1608] feat: Introduce ExpiringCloseableSupplier and
refactor ShuffleManagerClient creation (#1838)
457c86536 is described below
commit 457c865362e1dc573004b30c505287c253a6dba0
Author: xumanbu <[email protected]>
AuthorDate: Fri Jul 26 21:24:28 2024 +0800
[#1608] feat: Introduce ExpiringCloseableSupplier and refactor
ShuffleManagerClient creation (#1838)
### What changes were proposed in this pull request?
1. Introduce StatefulCloseable and ExpiringCloseableSupplier
2. Refactor ShuffleManagerClient to leverage ExpiringCloseableSupplier
### Why are the changes needed?
For better code quality
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing UTs and new UTs.
---
.../apache/spark/shuffle/RssSparkShuffleUtils.java | 48 +++---
.../shuffle/reader/RssFetchFailedIterator.java | 63 +++-----
.../BlockIdSelfManagedShuffleWriteClient.java | 13 +-
.../uniffle/shuffle/RssShuffleClientFactory.java | 12 +-
.../shuffle/manager/RssShuffleManagerBase.java | 36 +++--
.../apache/spark/shuffle/RssShuffleManager.java | 18 ++-
.../spark/shuffle/reader/RssShuffleReader.java | 12 +-
.../spark/shuffle/writer/RssShuffleWriter.java | 71 ++++-----
.../spark/shuffle/reader/RssShuffleReaderTest.java | 6 +-
.../spark/shuffle/writer/RssShuffleWriterTest.java | 8 +
.../apache/spark/shuffle/RssShuffleManager.java | 9 +-
.../spark/shuffle/reader/RssShuffleReader.java | 11 +-
.../spark/shuffle/writer/RssShuffleWriter.java | 84 +++++-----
.../spark/shuffle/reader/RssShuffleReaderTest.java | 6 +
.../spark/shuffle/writer/RssShuffleWriterTest.java | 14 ++
.../common/util/ExpiringCloseableSupplier.java | 110 +++++++++++++
.../uniffle/common/util/StatefulCloseable.java | 25 +++
.../common/util/ExpiringCloseableSupplierTest.java | 172 +++++++++++++++++++++
.../uniffle/test/ShuffleServerManagerTestBase.java | 13 +-
.../uniffle/client/api/ShuffleManagerClient.java | 5 +-
.../factory/ShuffleManagerClientFactory.java | 4 +-
.../client/impl/grpc/ShuffleManagerGrpcClient.java | 20 ++-
.../factory/ShuffleManagerClientFactoryTest.java | 5 +-
23 files changed, 545 insertions(+), 220 deletions(-)
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
index b3763df32..feee2a331 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
@@ -17,7 +17,6 @@
package org.apache.spark.shuffle;
-import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.Arrays;
@@ -25,6 +24,7 @@ import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
+import java.util.function.Supplier;
import scala.Option;
import scala.reflect.ClassTag;
@@ -43,21 +43,18 @@ import org.slf4j.LoggerFactory;
import org.apache.uniffle.client.api.CoordinatorClient;
import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.client.factory.CoordinatorClientFactory;
-import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
import org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest;
import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
import org.apache.uniffle.client.util.ClientUtils;
import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
-import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.exception.RssFetchFailedException;
import org.apache.uniffle.common.util.Constants;
import static
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED;
-import static org.apache.uniffle.common.util.Constants.DRIVER_HOST;
public class RssSparkShuffleUtils {
@@ -346,6 +343,7 @@ public class RssSparkShuffleUtils {
}
public static RssException reportRssFetchFailedException(
+ Supplier<ShuffleManagerClient> managerClientSupplier,
RssFetchFailedException rssFetchFailedException,
SparkConf sparkConf,
String appId,
@@ -355,32 +353,24 @@ public class RssSparkShuffleUtils {
RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
if (rssConf.getBoolean(RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED)
&& RssSparkShuffleUtils.isStageResubmitSupported()) {
- String driver = rssConf.getString(DRIVER_HOST, "");
- int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
- try (ShuffleManagerClient client =
- ShuffleManagerClientFactory.getInstance()
- .createShuffleManagerClient(ClientType.GRPC, driver, port)) {
- // todo: Create a new rpc interface to report failures in batch.
- for (int partitionId : failedPartitions) {
- RssReportShuffleFetchFailureRequest req =
- new RssReportShuffleFetchFailureRequest(
- appId,
- shuffleId,
- stageAttemptId,
- partitionId,
- rssFetchFailedException.getMessage());
- RssReportShuffleFetchFailureResponse response =
client.reportShuffleFetchFailure(req);
- if (response.getReSubmitWholeStage()) {
- // since we are going to roll out the whole stage, mapIndex
shouldn't matter, hence -1
- // is provided.
- FetchFailedException ffe =
- RssSparkShuffleUtils.createFetchFailedException(
- shuffleId, -1, partitionId, rssFetchFailedException);
- return new RssException(ffe);
- }
+ for (int partitionId : failedPartitions) {
+ RssReportShuffleFetchFailureRequest req =
+ new RssReportShuffleFetchFailureRequest(
+ appId,
+ shuffleId,
+ stageAttemptId,
+ partitionId,
+ rssFetchFailedException.getMessage());
+ RssReportShuffleFetchFailureResponse response =
+ managerClientSupplier.get().reportShuffleFetchFailure(req);
+ if (response.getReSubmitWholeStage()) {
+ // since we are going to roll out the whole stage, mapIndex
shouldn't matter, hence -1
+ // is provided.
+ FetchFailedException ffe =
+ RssSparkShuffleUtils.createFetchFailedException(
+ shuffleId, -1, partitionId, rssFetchFailedException);
+ return new RssException(ffe);
}
- } catch (IOException ioe) {
- LOG.info("Error closing shuffle manager client with error:", ioe);
}
}
return rssFetchFailedException;
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java
index c394f510b..1bc61dc74 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java
@@ -17,8 +17,8 @@
package org.apache.spark.shuffle.reader;
-import java.io.IOException;
import java.util.Objects;
+import java.util.function.Supplier;
import scala.Product2;
import scala.collection.AbstractIterator;
@@ -30,10 +30,8 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.uniffle.client.api.ShuffleManagerClient;
-import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
import org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest;
import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
-import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.exception.RssFetchFailedException;
@@ -52,8 +50,7 @@ public class RssFetchFailedIterator<K, C> extends
AbstractIterator<Product2<K, C
private int shuffleId;
private int partitionId;
private int stageAttemptId;
- private String reportServerHost;
- private int reportServerPort;
+ private Supplier<ShuffleManagerClient> managerClientSupplier;
private Builder() {}
@@ -77,19 +74,13 @@ public class RssFetchFailedIterator<K, C> extends
AbstractIterator<Product2<K, C
return this;
}
- Builder reportServerHost(String host) {
- this.reportServerHost = host;
- return this;
- }
-
- Builder port(int port) {
- this.reportServerPort = port;
+ Builder managerClientSupplier(Supplier<ShuffleManagerClient>
managerClientSupplier) {
+ this.managerClientSupplier = managerClientSupplier;
return this;
}
<K, C> RssFetchFailedIterator<K, C> build(Iterator<Product2<K, C>> iter) {
Objects.requireNonNull(this.appId);
- Objects.requireNonNull(this.reportServerHost);
return new RssFetchFailedIterator<>(this, iter);
}
}
@@ -98,37 +89,23 @@ public class RssFetchFailedIterator<K, C> extends
AbstractIterator<Product2<K, C
return new Builder();
}
- private static ShuffleManagerClient createShuffleManagerClient(String host,
int port)
- throws IOException {
- ClientType grpc = ClientType.GRPC;
- // host is passed from spark.driver.bindAddress, which would be set when
SparkContext is
- // constructed.
- return
ShuffleManagerClientFactory.getInstance().createShuffleManagerClient(grpc,
host, port);
- }
-
private RssException generateFetchFailedIfNecessary(RssFetchFailedException
e) {
- String driver = builder.reportServerHost;
- int port = builder.reportServerPort;
- // todo: reuse this manager client if this is a bottleneck.
- try (ShuffleManagerClient client = createShuffleManagerClient(driver,
port)) {
- RssReportShuffleFetchFailureRequest req =
- new RssReportShuffleFetchFailureRequest(
- builder.appId,
- builder.shuffleId,
- builder.stageAttemptId,
- builder.partitionId,
- e.getMessage());
- RssReportShuffleFetchFailureResponse response =
client.reportShuffleFetchFailure(req);
- if (response.getReSubmitWholeStage()) {
- // since we are going to roll out the whole stage, mapIndex shouldn't
matter, hence -1 is
- // provided.
- FetchFailedException ffe =
- RssSparkShuffleUtils.createFetchFailedException(
- builder.shuffleId, -1, builder.partitionId, e);
- return new RssException(ffe);
- }
- } catch (IOException ioe) {
- LOG.info("Error closing shuffle manager client with error:", ioe);
+ ShuffleManagerClient client = builder.managerClientSupplier.get();
+ RssReportShuffleFetchFailureRequest req =
+ new RssReportShuffleFetchFailureRequest(
+ builder.appId,
+ builder.shuffleId,
+ builder.stageAttemptId,
+ builder.partitionId,
+ e.getMessage());
+ RssReportShuffleFetchFailureResponse response =
client.reportShuffleFetchFailure(req);
+ if (response.getReSubmitWholeStage()) {
+ // since we are going to roll out the whole stage, mapIndex shouldn't
matter, hence -1 is
+ // provided.
+ FetchFailedException ffe =
+ RssSparkShuffleUtils.createFetchFailedException(
+ builder.shuffleId, -1, builder.partitionId, e);
+ return new RssException(ffe);
}
return e;
}
diff --git
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/BlockIdSelfManagedShuffleWriteClient.java
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/BlockIdSelfManagedShuffleWriteClient.java
index 1429bacbf..93aa3f0fc 100644
---
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/BlockIdSelfManagedShuffleWriteClient.java
+++
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/BlockIdSelfManagedShuffleWriteClient.java
@@ -22,6 +22,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
+import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
@@ -41,16 +42,16 @@ import org.apache.uniffle.common.util.BlockIdLayout;
* driver side.
*/
public class BlockIdSelfManagedShuffleWriteClient extends
ShuffleWriteClientImpl {
- private ShuffleManagerClient shuffleManagerClient;
+ private Supplier<ShuffleManagerClient> managerClientSupplier;
public BlockIdSelfManagedShuffleWriteClient(
RssShuffleClientFactory.ExtendWriteClientBuilder builder) {
super(builder);
- if (builder.getShuffleManagerClient() == null) {
+ if (builder.getManagerClientSupplier() == null) {
throw new RssException("Illegal empty shuffleManagerClient. This should
not happen");
}
- this.shuffleManagerClient = builder.getShuffleManagerClient();
+ this.managerClientSupplier = builder.getManagerClientSupplier();
}
@Override
@@ -73,7 +74,7 @@ public class BlockIdSelfManagedShuffleWriteClient extends
ShuffleWriteClientImpl
RssReportShuffleResultRequest request =
new RssReportShuffleResultRequest(
appId, shuffleId, taskAttemptId, partitionToBlockIds, bitmapNum);
- shuffleManagerClient.reportShuffleResult(request);
+ managerClientSupplier.get().reportShuffleResult(request);
}
@Override
@@ -85,7 +86,7 @@ public class BlockIdSelfManagedShuffleWriteClient extends
ShuffleWriteClientImpl
int partitionId) {
RssGetShuffleResultRequest request =
new RssGetShuffleResultRequest(appId, shuffleId, partitionId,
BlockIdLayout.DEFAULT);
- return shuffleManagerClient.getShuffleResult(request).getBlockIdBitmap();
+ return
managerClientSupplier.get().getShuffleResult(request).getBlockIdBitmap();
}
@Override
@@ -101,6 +102,6 @@ public class BlockIdSelfManagedShuffleWriteClient extends
ShuffleWriteClientImpl
RssGetShuffleResultForMultiPartRequest request =
new RssGetShuffleResultForMultiPartRequest(
appId, shuffleId, partitionIds, BlockIdLayout.DEFAULT);
- return
shuffleManagerClient.getShuffleResultForMultiPart(request).getBlockIdBitmap();
+ return
managerClientSupplier.get().getShuffleResultForMultiPart(request).getBlockIdBitmap();
}
}
diff --git
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/RssShuffleClientFactory.java
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/RssShuffleClientFactory.java
index c19d91324..bad10ab72 100644
---
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/RssShuffleClientFactory.java
+++
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/RssShuffleClientFactory.java
@@ -17,6 +17,8 @@
package org.apache.uniffle.shuffle;
+import java.util.function.Supplier;
+
import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.factory.ShuffleClientFactory;
@@ -41,18 +43,18 @@ public class RssShuffleClientFactory extends
ShuffleClientFactory {
public static class ExtendWriteClientBuilder<T extends
ExtendWriteClientBuilder<T>>
extends WriteClientBuilder<T> {
private boolean blockIdSelfManagedEnabled;
- private ShuffleManagerClient shuffleManagerClient;
+ private Supplier<ShuffleManagerClient> managerClientSupplier;
public boolean isBlockIdSelfManagedEnabled() {
return blockIdSelfManagedEnabled;
}
- public ShuffleManagerClient getShuffleManagerClient() {
- return shuffleManagerClient;
+ public Supplier<ShuffleManagerClient> getManagerClientSupplier() {
+ return managerClientSupplier;
}
- public T shuffleManagerClient(ShuffleManagerClient client) {
- this.shuffleManagerClient = client;
+ public T managerClientSupplier(Supplier<ShuffleManagerClient>
managerClientSupplier) {
+ this.managerClientSupplier = managerClientSupplier;
return self();
}
diff --git
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
index 6a281db2e..d314b9bb6 100644
---
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
+++
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
@@ -31,6 +31,7 @@ import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
+import java.util.function.Supplier;
import java.util.stream.Collectors;
import com.google.common.annotations.VisibleForTesting;
@@ -78,10 +79,12 @@ import org.apache.uniffle.common.ShuffleAssignmentsInfo;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.ConfigOption;
+import org.apache.uniffle.common.config.RssBaseConf;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.common.util.ExpiringCloseableSupplier;
import org.apache.uniffle.common.util.RetryUtils;
import org.apache.uniffle.shuffle.BlockIdManager;
@@ -104,7 +107,7 @@ public abstract class RssShuffleManagerBase implements
RssShuffleManagerInterfac
protected String clientType;
protected SparkConf sparkConf;
- protected ShuffleManagerClient shuffleManagerClient;
+ protected Supplier<ShuffleManagerClient> managerClientSupplier;
protected boolean rssStageRetryEnabled;
protected boolean rssStageRetryForWriteFailureEnabled;
protected boolean rssStageRetryForFetchFailureEnabled;
@@ -588,7 +591,8 @@ public abstract class RssShuffleManagerBase implements
RssShuffleManagerInterfac
RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest =
new RssPartitionToShuffleServerRequest(shuffleId);
RssReassignOnStageRetryResponse rpcPartitionToShufflerServer =
- getOrCreateShuffleManagerClient()
+ getOrCreateShuffleManagerClientSupplier()
+ .get()
.getPartitionToShufflerServerWithStageRetry(rssPartitionToShuffleServerRequest);
StageAttemptShuffleHandleInfo shuffleHandleInfo =
StageAttemptShuffleHandleInfo.fromProto(
@@ -607,25 +611,27 @@ public abstract class RssShuffleManagerBase implements
RssShuffleManagerInterfac
RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest =
new RssPartitionToShuffleServerRequest(shuffleId);
RssReassignOnBlockSendFailureResponse rpcPartitionToShufflerServer =
- getOrCreateShuffleManagerClient()
+ getOrCreateShuffleManagerClientSupplier()
+ .get()
.getPartitionToShufflerServerWithBlockRetry(rssPartitionToShuffleServerRequest);
MutableShuffleHandleInfo shuffleHandleInfo =
MutableShuffleHandleInfo.fromProto(rpcPartitionToShufflerServer.getHandle());
return shuffleHandleInfo;
}
- // todo: automatic close client when the client is idle to avoid too much
connections for spark
- // driver.
- protected ShuffleManagerClient getOrCreateShuffleManagerClient() {
- if (shuffleManagerClient == null) {
+ protected synchronized Supplier<ShuffleManagerClient>
getOrCreateShuffleManagerClientSupplier() {
+ if (managerClientSupplier == null) {
RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
String driver = rssConf.getString("driver.host", "");
int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
- this.shuffleManagerClient =
- ShuffleManagerClientFactory.getInstance()
- .createShuffleManagerClient(ClientType.GRPC, driver, port);
+ long rpcTimeout =
rssConf.getLong(RssBaseConf.RSS_CLIENT_TYPE_GRPC_TIMEOUT_MS);
+ this.managerClientSupplier =
+ ExpiringCloseableSupplier.of(
+ () ->
+ ShuffleManagerClientFactory.getInstance()
+ .createShuffleManagerClient(ClientType.GRPC, driver,
port, rpcTimeout));
}
- return shuffleManagerClient;
+ return managerClientSupplier;
}
@Override
@@ -808,6 +814,14 @@ public abstract class RssShuffleManagerBase implements
RssShuffleManagerInterfac
}
}
+ @Override
+ public void stop() {
+ if (managerClientSupplier != null
+ && managerClientSupplier instanceof ExpiringCloseableSupplier) {
+ ((ExpiringCloseableSupplier<ShuffleManagerClient>)
managerClientSupplier).close();
+ }
+ }
+
/**
* Creating the shuffleAssignmentInfo from the servers and partitionIds
*
diff --git
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 1e5bb4941..27db614bf 100644
---
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -214,16 +214,15 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
}
}
}
-
if (shuffleManagerRpcServiceEnabled) {
- this.shuffleManagerClient = getOrCreateShuffleManagerClient();
+ getOrCreateShuffleManagerClientSupplier();
}
this.shuffleWriteClient =
RssShuffleClientFactory.getInstance()
.createShuffleWriteClient(
RssShuffleClientFactory.newWriteBuilder()
.blockIdSelfManagedEnabled(blockIdSelfManagedEnabled)
- .shuffleManagerClient(shuffleManagerClient)
+ .managerClientSupplier(managerClientSupplier)
.clientType(clientType)
.retryMax(retryMax)
.retryIntervalMax(retryIntervalMax)
@@ -434,6 +433,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
this,
sparkConf,
shuffleWriteClient,
+ managerClientSupplier,
rssHandle,
this::markFailedTask,
context);
@@ -537,7 +537,8 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
blockIdBitmap,
taskIdBitmap,
RssSparkConfig.toRssConf(sparkConf),
- partitionToServers);
+ partitionToServers,
+ managerClientSupplier);
} else {
throw new RssException("Unexpected ShuffleHandle:" +
handle.getClass().getName());
}
@@ -573,6 +574,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
@Override
public void stop() {
+ super.stop();
if (heartBeatScheduledExecutorService != null) {
heartBeatScheduledExecutorService.shutdownNow();
}
@@ -719,7 +721,13 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
clientType, shuffleServerInfoSet, appId, shuffleId, partitionId);
} catch (RssFetchFailedException e) {
throw RssSparkShuffleUtils.reportRssFetchFailedException(
- e, sparkConf, appId, shuffleId, stageAttemptId,
Sets.newHashSet(partitionId));
+ managerClientSupplier,
+ e,
+ sparkConf,
+ appId,
+ shuffleId,
+ stageAttemptId,
+ Sets.newHashSet(partitionId));
}
}
diff --git
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index 3bf5840e8..4b4ec32c5 100644
---
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -19,6 +19,7 @@ package org.apache.spark.shuffle.reader;
import java.util.List;
import java.util.Map;
+import java.util.function.Supplier;
import scala.Function0;
import scala.Function2;
@@ -47,6 +48,7 @@ import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.client.api.ShuffleReadClient;
import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.client.util.RssClientConfig;
@@ -77,6 +79,7 @@ public class RssShuffleReader<K, C> implements
ShuffleReader<K, C> {
private List<ShuffleServerInfo> shuffleServerInfoList;
private Configuration hadoopConf;
private RssConf rssConf;
+ private Supplier<ShuffleManagerClient> managerClientSupplier;
public RssShuffleReader(
int startPartition,
@@ -90,7 +93,8 @@ public class RssShuffleReader<K, C> implements
ShuffleReader<K, C> {
Roaring64NavigableMap blockIdBitmap,
Roaring64NavigableMap taskIdBitmap,
RssConf rssConf,
- Map<Integer, List<ShuffleServerInfo>> partitionToServers) {
+ Map<Integer, List<ShuffleServerInfo>> partitionToServers,
+ Supplier<ShuffleManagerClient> managerClientSupplier) {
this.appId = rssShuffleHandle.getAppId();
this.startPartition = startPartition;
this.endPartition = endPartition;
@@ -107,6 +111,7 @@ public class RssShuffleReader<K, C> implements
ShuffleReader<K, C> {
this.hadoopConf = hadoopConf;
this.shuffleServerInfoList = (List<ShuffleServerInfo>)
(partitionToServers.get(startPartition));
this.rssConf = rssConf;
+ this.managerClientSupplier = managerClientSupplier;
expectedTaskIdsBitmapFilterEnable = shuffleServerInfoList.size() > 1;
}
@@ -235,16 +240,13 @@ public class RssShuffleReader<K, C> implements
ShuffleReader<K, C> {
// stage re-compute and shuffle manager server port are both set
if (rssConf.getBoolean(RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED)
&& rssConf.getInteger(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT, 0) > 0)
{
- String driver = rssConf.getString("driver.host", "");
- int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
resultIter =
RssFetchFailedIterator.newBuilder()
.appId(appId)
.shuffleId(shuffleId)
.partitionId(startPartition)
.stageAttemptId(context.stageAttemptNumber())
- .reportServerHost(driver)
- .port(port)
+ .managerClientSupplier(managerClientSupplier)
.build(resultIter);
}
return resultIter;
diff --git
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 5ac6a7e9e..4474c99c8 100644
---
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -17,7 +17,6 @@
package org.apache.spark.shuffle.writer;
-import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@@ -31,6 +30,7 @@ import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
+import java.util.function.Supplier;
import java.util.stream.Collectors;
import scala.Function1;
@@ -64,17 +64,13 @@ import org.slf4j.LoggerFactory;
import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.client.api.ShuffleWriteClient;
-import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
import org.apache.uniffle.client.impl.FailedBlockSendTracker;
import org.apache.uniffle.client.request.RssReassignServersRequest;
import org.apache.uniffle.client.request.RssReportShuffleWriteFailureRequest;
import org.apache.uniffle.client.response.RssReassignServersResponse;
import org.apache.uniffle.client.response.RssReportShuffleWriteFailureResponse;
-import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
-import org.apache.uniffle.common.config.RssClientConf;
-import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.exception.RssSendFailedException;
import org.apache.uniffle.common.exception.RssWaitFailedException;
@@ -114,6 +110,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
private final Set<Long> blockIds = Sets.newConcurrentHashSet();
private TaskContext taskContext;
private SparkConf sparkConf;
+ private Supplier<ShuffleManagerClient> managerClientSupplier;
public RssShuffleWriter(
String appId,
@@ -125,6 +122,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
RssShuffleManager shuffleManager,
SparkConf sparkConf,
ShuffleWriteClient shuffleWriteClient,
+ Supplier<ShuffleManagerClient> managerClientSupplier,
RssShuffleHandle<K, V, C> rssHandle,
SimpleShuffleHandleInfo shuffleHandleInfo,
TaskContext context) {
@@ -137,6 +135,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
shuffleManager,
sparkConf,
shuffleWriteClient,
+ managerClientSupplier,
rssHandle,
(tid) -> true,
shuffleHandleInfo,
@@ -153,6 +152,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
RssShuffleManager shuffleManager,
SparkConf sparkConf,
ShuffleWriteClient shuffleWriteClient,
+ Supplier<ShuffleManagerClient> managerClientSupplier,
RssShuffleHandle<K, V, C> rssHandle,
Function<String, Boolean> taskFailureCallback,
ShuffleHandleInfo shuffleHandleInfo,
@@ -172,6 +172,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
this.bitmapSplitNum =
sparkConf.get(RssSparkConfig.RSS_CLIENT_BITMAP_SPLIT_NUM);
this.serverToPartitionToBlockIds = Maps.newHashMap();
this.shuffleWriteClient = shuffleWriteClient;
+ this.managerClientSupplier = managerClientSupplier;
this.shuffleServersForData = shuffleHandleInfo.getServers();
this.partitionToServers =
shuffleHandleInfo.getAvailablePartitionServersForWriter();
this.isMemoryShuffleEnabled =
@@ -191,6 +192,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
RssShuffleManager shuffleManager,
SparkConf sparkConf,
ShuffleWriteClient shuffleWriteClient,
+ Supplier<ShuffleManagerClient> managerClientSupplier,
RssShuffleHandle<K, V, C> rssHandle,
Function<String, Boolean> taskFailureCallback,
TaskContext context) {
@@ -203,6 +205,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
shuffleManager,
sparkConf,
shuffleWriteClient,
+ managerClientSupplier,
rssHandle,
taskFailureCallback,
shuffleManager.getShuffleHandleInfo(rssHandle),
@@ -528,14 +531,6 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
return shuffleWriteMetrics;
}
- private static ShuffleManagerClient createShuffleManagerClient(String host,
int port)
- throws IOException {
- ClientType grpc = ClientType.GRPC;
- // Host can be inferred from `spark.driver.bindAddress`, which would be
set when SparkContext is
- // constructed.
- return
ShuffleManagerClientFactory.getInstance().createShuffleManagerClient(grpc,
host, port);
- }
-
private void throwFetchFailedIfNecessary(Exception e) {
// The shuffleServer is registered only when a Block fails to be sent
if (e instanceof RssSendFailedException) {
@@ -550,34 +545,28 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
taskContext.stageAttemptNumber(),
shuffleServerInfos,
e.getMessage());
- RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
- String driver = rssConf.getString("driver.host", "");
- int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
- try (ShuffleManagerClient shuffleManagerClient =
createShuffleManagerClient(driver, port)) {
- RssReportShuffleWriteFailureResponse response =
- shuffleManagerClient.reportShuffleWriteFailure(req);
- if (response.getReSubmitWholeStage()) {
- // The shuffle server is reassigned.
- RssReassignServersRequest rssReassignServersRequest =
- new RssReassignServersRequest(
- taskContext.stageId(),
- taskContext.stageAttemptNumber(),
- shuffleId,
- partitioner.numPartitions());
- RssReassignServersResponse rssReassignServersResponse =
-
shuffleManagerClient.reassignOnStageResubmit(rssReassignServersRequest);
- LOG.info(
- "Whether the reassignment is successful: {}",
- rssReassignServersResponse.isNeedReassign());
- // since we are going to roll out the whole stage, mapIndex
shouldn't matter, hence -1 is
- // provided.
- FetchFailedException ffe =
- RssSparkShuffleUtils.createFetchFailedException(
- shuffleId, -1, taskContext.stageAttemptNumber(), e);
- throw new RssException(ffe);
- }
- } catch (IOException ioe) {
- LOG.info("Error closing shuffle manager client with error:", ioe);
+ RssReportShuffleWriteFailureResponse response =
+ managerClientSupplier.get().reportShuffleWriteFailure(req);
+ if (response.getReSubmitWholeStage()) {
+ // The shuffle server is reassigned.
+ RssReassignServersRequest rssReassignServersRequest =
+ new RssReassignServersRequest(
+ taskContext.stageId(),
+ taskContext.stageAttemptNumber(),
+ shuffleId,
+ partitioner.numPartitions());
+ RssReassignServersResponse rssReassignServersResponse =
+
managerClientSupplier.get().reassignOnStageResubmit(rssReassignServersRequest);
+ LOG.info(
+ "Whether the reassignment is successful: {}",
+ rssReassignServersResponse.isNeedReassign());
+ // since we are going to roll out the whole stage, mapIndex shouldn't
matter, hence -1
+ // is
+ // provided.
+ FetchFailedException ffe =
+ RssSparkShuffleUtils.createFetchFailedException(
+ shuffleId, -1, taskContext.stageAttemptNumber(), e);
+ throw new RssException(ffe);
}
}
throw new RssException(e);
diff --git
a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
index f09223b1c..78fe7dec0 100644
---
a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
+++
b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
@@ -35,9 +35,11 @@ import org.apache.spark.shuffle.RssShuffleHandle;
import org.junit.jupiter.api.Test;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
+import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.util.ExpiringCloseableSupplier;
import org.apache.uniffle.storage.handler.impl.HadoopShuffleWriteHandler;
import org.apache.uniffle.storage.util.StorageType;
@@ -85,6 +87,7 @@ public class RssShuffleReaderTest extends
AbstractRssReaderTest {
rssConf.set(RssClientConf.RSS_STORAGE_TYPE, StorageType.HDFS.name());
rssConf.set(RssClientConf.RSS_INDEX_READ_LIMIT, 1000);
rssConf.set(RssClientConf.RSS_CLIENT_READ_BUFFER_SIZE, "1000");
+ final ShuffleManagerClient mockShuffleManagerClient =
mock(ShuffleManagerClient.class);
RssShuffleReader<String, String> rssShuffleReaderSpy =
spy(
new RssShuffleReader<>(
@@ -99,7 +102,8 @@ public class RssShuffleReaderTest extends
AbstractRssReaderTest {
blockIdBitmap,
taskIdBitmap,
rssConf,
- partitionToServers));
+ partitionToServers,
+ ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient)));
validateResult(rssShuffleReaderSpy.read(), expectedData, 10);
}
diff --git
a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index e039ad9d5..779f94117 100644
---
a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++
b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -48,12 +48,14 @@ import org.apache.spark.shuffle.RssSparkConfig;
import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo;
import org.junit.jupiter.api.Test;
+import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.impl.FailedBlockSendTracker;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.common.util.ExpiringCloseableSupplier;
import org.apache.uniffle.storage.util.StorageType;
import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -88,6 +90,7 @@ public class RssShuffleWriterTest {
Serializer kryoSerializer = new KryoSerializer(conf);
ShuffleWriteClient mockShuffleWriteClient = mock(ShuffleWriteClient.class);
+ ShuffleManagerClient mockShuffleManagerClient =
mock(ShuffleManagerClient.class);
Partitioner mockPartitioner = mock(Partitioner.class);
ShuffleDependency<String, String, String> mockDependency =
mock(ShuffleDependency.class);
RssShuffleHandle<String, String, String> mockHandle =
mock(RssShuffleHandle.class);
@@ -124,6 +127,7 @@ public class RssShuffleWriterTest {
manager,
conf,
mockShuffleWriteClient,
+ ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
mockHandle,
mockShuffleHandleInfo,
contextMock);
@@ -234,6 +238,7 @@ public class RssShuffleWriterTest {
Partitioner mockPartitioner = mock(Partitioner.class);
ShuffleDependency<String, String, String> mockDependency =
mock(ShuffleDependency.class);
final ShuffleWriteClient mockShuffleWriteClient =
mock(ShuffleWriteClient.class);
+ final ShuffleManagerClient mockShuffleManagerClient =
mock(ShuffleManagerClient.class);
RssShuffleHandle<String, String, String> mockHandle =
mock(RssShuffleHandle.class);
when(mockHandle.getDependency()).thenReturn(mockDependency);
Serializer kryoSerializer = new KryoSerializer(conf);
@@ -299,6 +304,7 @@ public class RssShuffleWriterTest {
manager,
conf,
mockShuffleWriteClient,
+ ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
mockHandle,
mockShuffleHandleInfo,
contextMock);
@@ -348,6 +354,7 @@ public class RssShuffleWriterTest {
@Test
public void postBlockEventTest() throws Exception {
final ShuffleWriteMetrics mockMetrics = mock(ShuffleWriteMetrics.class);
+ final ShuffleManagerClient mockShuffleManagerClient =
mock(ShuffleManagerClient.class);
ShuffleDependency<String, String, String> mockDependency =
mock(ShuffleDependency.class);
Partitioner mockPartitioner = mock(Partitioner.class);
when(mockDependency.partitioner()).thenReturn(mockPartitioner);
@@ -411,6 +418,7 @@ public class RssShuffleWriterTest {
manager,
conf,
mockWriteClient,
+ ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
mockHandle,
mockShuffleHandleInfo,
contextMock);
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index bf42bf361..92e630df2 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -239,7 +239,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
}
}
if (shuffleManagerRpcServiceEnabled) {
- this.shuffleManagerClient = getOrCreateShuffleManagerClient();
+ getOrCreateShuffleManagerClientSupplier();
}
int unregisterThreadPoolSize =
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE);
@@ -253,7 +253,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
.createShuffleWriteClient(
RssShuffleClientFactory.newWriteBuilder()
.blockIdSelfManagedEnabled(blockIdSelfManagedEnabled)
- .shuffleManagerClient(shuffleManagerClient)
+ .managerClientSupplier(managerClientSupplier)
.clientType(clientType)
.retryMax(retryMax)
.retryIntervalMax(retryIntervalMax)
@@ -523,6 +523,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
this,
sparkConf,
shuffleWriteClient,
+ managerClientSupplier,
rssHandle,
this::markFailedTask,
context);
@@ -696,6 +697,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
blockIdBitmap, startPartition, endPartition, blockIdLayout),
taskIdBitmap,
readMetrics,
+ managerClientSupplier,
RssSparkConfig.toRssConf(sparkConf),
dataDistributionType,
shuffleHandleInfo.getAllPartitionServersForReader());
@@ -853,6 +855,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
@Override
public void stop() {
+ super.stop();
if (heartBeatScheduledExecutorService != null) {
heartBeatScheduledExecutorService.shutdownNow();
}
@@ -1031,7 +1034,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
replicaRequirementTracking);
} catch (RssFetchFailedException e) {
throw RssSparkShuffleUtils.reportRssFetchFailedException(
- e, sparkConf, appId, shuffleId, stageAttemptId, failedPartitions);
+ managerClientSupplier, e, sparkConf, appId, shuffleId,
stageAttemptId, failedPartitions);
}
}
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index bf47ced6b..19682bd65 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -19,6 +19,7 @@ package org.apache.spark.shuffle.reader;
import java.util.List;
import java.util.Map;
+import java.util.function.Supplier;
import scala.Function0;
import scala.Function1;
@@ -49,6 +50,7 @@ import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.client.api.ShuffleReadClient;
import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.client.util.RssClientConfig;
@@ -58,7 +60,6 @@ import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import static
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED;
-import static org.apache.uniffle.common.util.Constants.DRIVER_HOST;
public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
private static final Logger LOG =
LoggerFactory.getLogger(RssShuffleReader.class);
@@ -83,6 +84,7 @@ public class RssShuffleReader<K, C> implements
ShuffleReader<K, C> {
private ShuffleReadMetrics readMetrics;
private RssConf rssConf;
private ShuffleDataDistributionType dataDistributionType;
+ private Supplier<ShuffleManagerClient> managerClientSupplier;
public RssShuffleReader(
int startPartition,
@@ -97,6 +99,7 @@ public class RssShuffleReader<K, C> implements
ShuffleReader<K, C> {
Map<Integer, Roaring64NavigableMap> partitionToExpectBlocks,
Roaring64NavigableMap taskIdBitmap,
ShuffleReadMetrics readMetrics,
+ Supplier<ShuffleManagerClient> managerClientSupplier,
RssConf rssConf,
ShuffleDataDistributionType dataDistributionType,
Map<Integer, List<ShuffleServerInfo>> allPartitionToServers) {
@@ -120,6 +123,7 @@ public class RssShuffleReader<K, C> implements
ShuffleReader<K, C> {
this.partitionToShuffleServers = allPartitionToServers;
this.rssConf = rssConf;
this.dataDistributionType = dataDistributionType;
+ this.managerClientSupplier = managerClientSupplier;
}
@Override
@@ -193,16 +197,13 @@ public class RssShuffleReader<K, C> implements
ShuffleReader<K, C> {
// resubmit stage and shuffle manager server port are both set
if (rssConf.getBoolean(RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED)
&& rssConf.getInteger(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT, 0) > 0)
{
- String driver = rssConf.getString(DRIVER_HOST, "");
- int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
resultIter =
RssFetchFailedIterator.newBuilder()
.appId(appId)
.shuffleId(shuffleId)
.partitionId(startPartition)
.stageAttemptId(context.stageAttemptNumber())
- .reportServerHost(driver)
- .port(port)
+ .managerClientSupplier(managerClientSupplier)
.build(resultIter);
}
return resultIter;
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 870141c4b..24a3b8c1c 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -17,7 +17,6 @@
package org.apache.spark.shuffle.writer;
-import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@@ -36,6 +35,7 @@ import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
+import java.util.function.Supplier;
import java.util.stream.Collectors;
import scala.Function1;
@@ -71,7 +71,6 @@ import org.slf4j.LoggerFactory;
import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.client.api.ShuffleWriteClient;
-import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
import org.apache.uniffle.client.impl.FailedBlockSendTracker;
import org.apache.uniffle.client.impl.TrackingBlockStatus;
import org.apache.uniffle.client.request.RssReassignOnBlockSendFailureRequest;
@@ -80,12 +79,10 @@ import
org.apache.uniffle.client.request.RssReportShuffleWriteFailureRequest;
import
org.apache.uniffle.client.response.RssReassignOnBlockSendFailureResponse;
import org.apache.uniffle.client.response.RssReassignServersResponse;
import org.apache.uniffle.client.response.RssReportShuffleWriteFailureResponse;
-import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.ReceivingFailureServer;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssClientConf;
-import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.exception.RssSendFailedException;
import org.apache.uniffle.common.exception.RssWaitFailedException;
@@ -143,6 +140,8 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
private static final Set<StatusCode> STATUS_CODE_WITHOUT_BLOCK_RESEND =
Sets.newHashSet(StatusCode.NO_REGISTER);
+ private final Supplier<ShuffleManagerClient> managerClientSupplier;
+
// Only for tests
@VisibleForTesting
public RssShuffleWriter(
@@ -155,6 +154,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
RssShuffleManager shuffleManager,
SparkConf sparkConf,
ShuffleWriteClient shuffleWriteClient,
+ Supplier<ShuffleManagerClient> managerClientSupplier,
RssShuffleHandle<K, V, C> rssHandle,
ShuffleHandleInfo shuffleHandleInfo,
TaskContext context) {
@@ -167,6 +167,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
shuffleManager,
sparkConf,
shuffleWriteClient,
+ managerClientSupplier,
rssHandle,
(tid) -> true,
shuffleHandleInfo,
@@ -184,6 +185,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
RssShuffleManager shuffleManager,
SparkConf sparkConf,
ShuffleWriteClient shuffleWriteClient,
+ Supplier<ShuffleManagerClient> managerClientSupplier,
RssShuffleHandle<K, V, C> rssHandle,
Function<String, Boolean> taskFailureCallback,
ShuffleHandleInfo shuffleHandleInfo,
@@ -217,6 +219,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
this.shuffleHandleInfo = shuffleHandleInfo;
this.taskContext = context;
this.sparkConf = sparkConf;
+ this.managerClientSupplier = managerClientSupplier;
this.blockFailSentRetryEnabled =
sparkConf.getBoolean(
RssSparkConfig.SPARK_RSS_CONFIG_PREFIX
@@ -235,6 +238,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
RssShuffleManager shuffleManager,
SparkConf sparkConf,
ShuffleWriteClient shuffleWriteClient,
+ Supplier<ShuffleManagerClient> managerClientSupplier,
RssShuffleHandle<K, V, C> rssHandle,
Function<String, Boolean> taskFailureCallback,
TaskContext context) {
@@ -247,6 +251,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
shuffleManager,
sparkConf,
shuffleWriteClient,
+ managerClientSupplier,
rssHandle,
taskFailureCallback,
shuffleManager.getShuffleHandleInfo(rssHandle),
@@ -618,14 +623,11 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
LOG.info(
"Initiate reassignOnBlockSendFailure. failure partition servers: {}",
failurePartitionToServers);
- RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
- String driver = rssConf.getString("driver.host", "");
- int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
- try (ShuffleManagerClient shuffleManagerClient =
createShuffleManagerClient(driver, port)) {
- String executorId = SparkEnv.get().executorId();
- long taskAttemptId = taskContext.taskAttemptId();
- int stageId = taskContext.stageId();
- int stageAttemptNum = taskContext.stageAttemptNumber();
+ String executorId = SparkEnv.get().executorId();
+ long taskAttemptId = taskContext.taskAttemptId();
+ int stageId = taskContext.stageId();
+ int stageAttemptNum = taskContext.stageAttemptNumber();
+ try {
RssReassignOnBlockSendFailureRequest request =
new RssReassignOnBlockSendFailureRequest(
shuffleId,
@@ -635,7 +637,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
stageId,
stageAttemptNum);
RssReassignOnBlockSendFailureResponse response =
- shuffleManagerClient.reassignOnBlockSendFailure(request);
+ managerClientSupplier.get().reassignOnBlockSendFailure(request);
if (response.getStatusCode() != StatusCode.SUCCESS) {
String msg =
String.format(
@@ -835,14 +837,6 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
return bufferManager;
}
- private static ShuffleManagerClient createShuffleManagerClient(String host,
int port)
- throws IOException {
- ClientType grpc = ClientType.GRPC;
- // Host can be inferred from `spark.driver.bindAddress`, which would be
set when SparkContext is
- // constructed.
- return
ShuffleManagerClientFactory.getInstance().createShuffleManagerClient(grpc,
host, port);
- }
-
private void throwFetchFailedIfNecessary(Exception e) {
// The shuffleServer is registered only when a Block fails to be sent
if (e instanceof RssSendFailedException) {
@@ -857,33 +851,27 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
taskContext.stageAttemptNumber(),
shuffleServerInfos,
e.getMessage());
- RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
- String driver = rssConf.getString("driver.host", "");
- int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
- try (ShuffleManagerClient shuffleManagerClient =
createShuffleManagerClient(driver, port)) {
- RssReportShuffleWriteFailureResponse response =
- shuffleManagerClient.reportShuffleWriteFailure(req);
- if (response.getReSubmitWholeStage()) {
- RssReassignServersRequest rssReassignServersRequest =
- new RssReassignServersRequest(
- taskContext.stageId(),
- taskContext.stageAttemptNumber(),
- shuffleId,
- partitioner.numPartitions());
- RssReassignServersResponse rssReassignServersResponse =
-
shuffleManagerClient.reassignOnStageResubmit(rssReassignServersRequest);
- LOG.info(
- "Whether the reassignment is successful: {}",
- rssReassignServersResponse.isNeedReassign());
- // since we are going to roll out the whole stage, mapIndex
shouldn't matter, hence -1 is
- // provided.
- FetchFailedException ffe =
- RssSparkShuffleUtils.createFetchFailedException(
- shuffleId, -1, taskContext.stageAttemptNumber(), e);
- throw new RssException(ffe);
- }
- } catch (IOException ioe) {
- LOG.info("Error closing shuffle manager client with error:", ioe);
+ RssReportShuffleWriteFailureResponse response =
+ managerClientSupplier.get().reportShuffleWriteFailure(req);
+ if (response.getReSubmitWholeStage()) {
+ RssReassignServersRequest rssReassignServersRequest =
+ new RssReassignServersRequest(
+ taskContext.stageId(),
+ taskContext.stageAttemptNumber(),
+ shuffleId,
+ partitioner.numPartitions());
+ RssReassignServersResponse rssReassignServersResponse =
+
managerClientSupplier.get().reassignOnStageResubmit(rssReassignServersRequest);
+ LOG.info(
+ "Whether the reassignment is successful: {}",
+ rssReassignServersResponse.isNeedReassign());
+ // since we are going to roll out the whole stage, mapIndex shouldn't
matter, hence -1
+ // is
+ // provided.
+ FetchFailedException ffe =
+ RssSparkShuffleUtils.createFetchFailedException(
+ shuffleId, -1, taskContext.stageAttemptNumber(), e);
+ throw new RssException(ffe);
}
}
throw new RssException(e);
diff --git
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
index aaff4cb8e..bc77f7192 100644
---
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
+++
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
@@ -36,10 +36,12 @@ import org.apache.spark.shuffle.RssShuffleHandle;
import org.junit.jupiter.api.Test;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
+import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.util.ExpiringCloseableSupplier;
import org.apache.uniffle.storage.handler.impl.HadoopShuffleWriteHandler;
import org.apache.uniffle.storage.util.StorageType;
@@ -93,6 +95,7 @@ public class RssShuffleReaderTest extends
AbstractRssReaderTest {
rssConf.set(RssClientConf.RSS_STORAGE_TYPE, StorageType.HDFS.name());
rssConf.set(RssClientConf.RSS_INDEX_READ_LIMIT, 1000);
rssConf.set(RssClientConf.RSS_CLIENT_READ_BUFFER_SIZE, "1000");
+ ShuffleManagerClient mockShuffleManagerClient =
mock(ShuffleManagerClient.class);
RssShuffleReader<String, String> rssShuffleReaderSpy =
spy(
new RssShuffleReader<>(
@@ -108,6 +111,7 @@ public class RssShuffleReaderTest extends
AbstractRssReaderTest {
partitionToExpectBlocks,
taskIdBitmap,
new ShuffleReadMetrics(),
+ ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
rssConf,
ShuffleDataDistributionType.NORMAL,
partitionToServers));
@@ -131,6 +135,7 @@ public class RssShuffleReaderTest extends
AbstractRssReaderTest {
partitionToExpectBlocks,
taskIdBitmap,
new ShuffleReadMetrics(),
+ ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
rssConf,
ShuffleDataDistributionType.NORMAL,
partitionToServers));
@@ -151,6 +156,7 @@ public class RssShuffleReaderTest extends
AbstractRssReaderTest {
partitionToExpectBlocks,
Roaring64NavigableMap.bitmapOf(),
new ShuffleReadMetrics(),
+ ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
rssConf,
ShuffleDataDistributionType.NORMAL,
partitionToServers));
diff --git
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index 53a8e7143..a4317aae8 100644
---
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -55,12 +55,14 @@ import
org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo;
import org.awaitility.Awaitility;
import org.junit.jupiter.api.Test;
+import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.impl.FailedBlockSendTracker;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.common.util.ExpiringCloseableSupplier;
import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.storage.util.StorageType;
@@ -133,6 +135,7 @@ public class RssShuffleWriterTest {
Serializer kryoSerializer = new KryoSerializer(conf);
Partitioner mockPartitioner = mock(Partitioner.class);
final ShuffleWriteClient mockShuffleWriteClient =
mock(ShuffleWriteClient.class);
+ final ShuffleManagerClient mockShuffleManagerClient =
mock(ShuffleManagerClient.class);
ShuffleDependency<String, String, String> mockDependency =
mock(ShuffleDependency.class);
RssShuffleHandle<String, String, String> mockHandle =
mock(RssShuffleHandle.class);
when(mockHandle.getDependency()).thenReturn(mockDependency);
@@ -179,6 +182,7 @@ public class RssShuffleWriterTest {
manager,
conf,
mockShuffleWriteClient,
+ ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
mockHandle,
shuffleHandle,
contextMock);
@@ -385,6 +389,7 @@ public class RssShuffleWriterTest {
Serializer kryoSerializer = new KryoSerializer(conf);
Partitioner mockPartitioner = mock(Partitioner.class);
final ShuffleWriteClient mockShuffleWriteClient =
mock(ShuffleWriteClient.class);
+ final ShuffleManagerClient mockShuffleManagerClient =
mock(ShuffleManagerClient.class);
ShuffleDependency<String, String, String> mockDependency =
mock(ShuffleDependency.class);
RssShuffleHandle<String, String, String> mockHandle =
mock(RssShuffleHandle.class);
when(mockHandle.getDependency()).thenReturn(mockDependency);
@@ -450,6 +455,7 @@ public class RssShuffleWriterTest {
manager,
conf,
mockShuffleWriteClient,
+ ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
mockHandle,
shuffleHandleInfo,
contextMock);
@@ -552,6 +558,7 @@ public class RssShuffleWriterTest {
conf, false, null, successBlocks, taskToFailedBlockSendTracker);
ShuffleWriteClient mockShuffleWriteClient = mock(ShuffleWriteClient.class);
+ ShuffleManagerClient mockShuffleManagerClient =
mock(ShuffleManagerClient.class);
Partitioner mockPartitioner = mock(Partitioner.class);
RssShuffleHandle<String, String, String> mockHandle =
mock(RssShuffleHandle.class);
ShuffleDependency<String, String, String> mockDependency =
mock(ShuffleDependency.class);
@@ -587,6 +594,7 @@ public class RssShuffleWriterTest {
manager,
conf,
mockShuffleWriteClient,
+ ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
mockHandle,
mockShuffleHandleInfo,
contextMock);
@@ -714,6 +722,7 @@ public class RssShuffleWriterTest {
Serializer kryoSerializer = new KryoSerializer(conf);
Partitioner mockPartitioner = mock(Partitioner.class);
final ShuffleWriteClient mockShuffleWriteClient =
mock(ShuffleWriteClient.class);
+ final ShuffleManagerClient mockShuffleManagerClient =
mock(ShuffleManagerClient.class);
ShuffleDependency<String, String, String> mockDependency =
mock(ShuffleDependency.class);
RssShuffleHandle<String, String, String> mockHandle =
mock(RssShuffleHandle.class);
when(mockHandle.getDependency()).thenReturn(mockDependency);
@@ -734,6 +743,7 @@ public class RssShuffleWriterTest {
manager,
conf,
mockShuffleWriteClient,
+ ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
mockHandle,
mockShuffleHandleInfo,
contextMock);
@@ -794,6 +804,7 @@ public class RssShuffleWriterTest {
Serializer kryoSerializer = new KryoSerializer(conf);
Partitioner mockPartitioner = mock(Partitioner.class);
final ShuffleWriteClient mockShuffleWriteClient =
mock(ShuffleWriteClient.class);
+ final ShuffleManagerClient mockShuffleManagerClient =
mock(ShuffleManagerClient.class);
ShuffleDependency<String, String, String> mockDependency =
mock(ShuffleDependency.class);
RssShuffleHandle<String, String, String> mockHandle =
mock(RssShuffleHandle.class);
when(mockHandle.getDependency()).thenReturn(mockDependency);
@@ -857,6 +868,7 @@ public class RssShuffleWriterTest {
manager,
conf,
mockShuffleWriteClient,
+ ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
mockHandle,
mockShuffleHandleInfo,
contextMock);
@@ -958,6 +970,7 @@ public class RssShuffleWriterTest {
TaskContext contextMock = mock(TaskContext.class);
SimpleShuffleHandleInfo mockShuffleHandleInfo =
mock(SimpleShuffleHandleInfo.class);
ShuffleWriteClient mockWriteClient = mock(ShuffleWriteClient.class);
+ ShuffleManagerClient mockShuffleManagerClient =
mock(ShuffleManagerClient.class);
List<ShuffleBlockInfo> shuffleBlockInfoList = createShuffleBlockList(1,
31);
RssShuffleWriter<String, String, String> writer =
@@ -971,6 +984,7 @@ public class RssShuffleWriterTest {
mockShuffleManager,
conf,
mockWriteClient,
+ ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
mockHandle,
mockShuffleHandleInfo,
contextMock);
diff --git
a/common/src/main/java/org/apache/uniffle/common/util/ExpiringCloseableSupplier.java
b/common/src/main/java/org/apache/uniffle/common/util/ExpiringCloseableSupplier.java
new file mode 100644
index 000000000..f36f9be0c
--- /dev/null
+++
b/common/src/main/java/org/apache/uniffle/common/util/ExpiringCloseableSupplier.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.util;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ScheduledFuture;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Supplier;
+
+import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A Supplier that caches the object built by its delegate and closes it after an idle delay, so
+ * with ExpiringCloseableSupplier, manual closure of the obtained object may not be necessary.
+ */
+public class ExpiringCloseableSupplier<T extends StatefulCloseable>
+ implements Supplier<T>, Serializable {
+ private static final long serialVersionUID = 0;
+ private static final Logger LOG =
LoggerFactory.getLogger(ExpiringCloseableSupplier.class);
+ private static final int DEFAULT_DELAY_CLOSE_INTERVAL = 60000;
+ private static final ScheduledExecutorService executor =
+
ThreadUtils.getDaemonSingleThreadScheduledExecutor("ExpiringCloseableSupplier");
+
+ private final Supplier<T> delegate;
+ private final long delayCloseInterval;
+
+ private transient volatile ScheduledFuture<?> future;
+
+ @SuppressFBWarnings("SE_TRANSIENT_FIELD_NOT_RESTORED")
+ private transient volatile long accessTime = System.currentTimeMillis();
+
+ private transient volatile T t;
+
+ private ExpiringCloseableSupplier(Supplier<T> delegate, long
delayCloseInterval) {
+ this.delegate = delegate;
+ this.delayCloseInterval = delayCloseInterval;
+ }
+
+ public synchronized T get() {
+ accessTime = System.currentTimeMillis();
+ if (t == null || t.isClosed()) {
+ this.t = delegate.get();
+ ensureCloseFutureScheduled();
+ }
+ return t;
+ }
+
+ public synchronized void close() {
+ try {
+ if (t != null && !t.isClosed()) {
+ t.close();
+ }
+ } catch (IOException ioe) {
+ LOG.warn("Failed to close {} the resource", t.getClass().getName(), ioe);
+ } finally {
+ this.t = null;
+ this.accessTime = System.currentTimeMillis();
+ cancelCloseFuture();
+ }
+ }
+
+ private void tryClose() {
+ if (System.currentTimeMillis() - accessTime > delayCloseInterval) {
+ close();
+ }
+ }
+
+ private void ensureCloseFutureScheduled() {
+ cancelCloseFuture();
+ this.future =
+ executor.scheduleAtFixedRate(
+ this::tryClose, delayCloseInterval, delayCloseInterval,
TimeUnit.MILLISECONDS);
+ }
+
+ private void cancelCloseFuture() {
+ if (future != null && !future.isDone()) {
+ future.cancel(false);
+ this.future = null;
+ }
+ }
+
+ public static <T extends StatefulCloseable> ExpiringCloseableSupplier<T> of(
+ Supplier<T> delegate) {
+ return new ExpiringCloseableSupplier<>(delegate,
DEFAULT_DELAY_CLOSE_INTERVAL);
+ }
+
+ public static <T extends StatefulCloseable> ExpiringCloseableSupplier<T> of(
+ Supplier<T> delegate, long delayCloseInterval) {
+ return new ExpiringCloseableSupplier<>(delegate, delayCloseInterval);
+ }
+}
diff --git
a/common/src/main/java/org/apache/uniffle/common/util/StatefulCloseable.java
b/common/src/main/java/org/apache/uniffle/common/util/StatefulCloseable.java
new file mode 100644
index 000000000..a4a2453d6
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/util/StatefulCloseable.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.util;
+
+import java.io.Closeable;
+
+/** A Closeable that exposes its closed state; required by ExpiringCloseableSupplier. */
+public interface StatefulCloseable extends Closeable {
+ boolean isClosed();
+}
diff --git
a/common/src/test/java/org/apache/uniffle/common/util/ExpiringCloseableSupplierTest.java
b/common/src/test/java/org/apache/uniffle/common/util/ExpiringCloseableSupplierTest.java
new file mode 100644
index 000000000..0f791ceab
--- /dev/null
+++
b/common/src/test/java/org/apache/uniffle/common/util/ExpiringCloseableSupplierTest.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.util;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.function.Supplier;
+
+import com.google.common.collect.Lists;
+import com.google.common.util.concurrent.Uninterruptibles;
+import org.apache.commons.lang3.SerializationUtils;
+import org.awaitility.Awaitility;
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNotSame;
+import static org.junit.jupiter.api.Assertions.assertSame;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+class ExpiringCloseableSupplierTest {
+
+ @Test
+ void testCacheable() {
+ Supplier<MockClient> cf = () -> new MockClient(false);
+ ExpiringCloseableSupplier<MockClient> mockClientSupplier = ExpiringCloseableSupplier.of(cf);
+
+ MockClient mockClient = mockClientSupplier.get();
+ MockClient mockClient2 = mockClientSupplier.get();
+ assertSame(mockClient, mockClient2);
+ mockClientSupplier.close();
+ mockClientSupplier.close();
+ }
+
+ @Test
+ void testAutoCloseable() {
+ Supplier<MockClient> cf = () -> new MockClient(true);
+ ExpiringCloseableSupplier<MockClient> mockClientSupplier = ExpiringCloseableSupplier.of(cf, 10);
+ MockClient mockClient1 = mockClientSupplier.get();
+ assertNotNull(mockClient1);
+ Uninterruptibles.sleepUninterruptibly(30, TimeUnit.MILLISECONDS);
+ assertTrue(mockClient1.isClosed());
+ MockClient mockClient2 = mockClientSupplier.get();
+ assertNotSame(mockClient1, mockClient2);
+ mockClientSupplier.close();
+ }
+
+ @Test
+ void testRenew() {
+ Supplier<MockClient> cf = () -> new MockClient(true);
+ ExpiringCloseableSupplier<MockClient> mockClientSupplier = ExpiringCloseableSupplier.of(cf);
+ MockClient mockClient = mockClientSupplier.get();
+ mockClientSupplier.close();
+ MockClient mockClient2 = mockClientSupplier.get();
+ assertNotSame(mockClient, mockClient2);
+ }
+
+ @Test
+ void testReClose() {
+ Supplier<MockClient> cf = () -> new MockClient(true);
+ ExpiringCloseableSupplier<MockClient> mockClientSupplier = ExpiringCloseableSupplier.of(cf);
+ mockClientSupplier.get();
+ mockClientSupplier.close();
+ mockClientSupplier.close();
+ }
+
+ @Test
+ void testDelegateExtendClose() throws IOException {
+ Supplier<MockClient> cf = () -> new MockClient(false);
+ ExpiringCloseableSupplier<MockClient> mockClientSupplier = ExpiringCloseableSupplier.of(cf);
+ MockClient mockClient = mockClientSupplier.get();
+ mockClient.close();
+ assertTrue(mockClient.isClosed());
+
+ MockClient mockClient1 = mockClientSupplier.get();
+ assertNotSame(mockClient, mockClient1);
+ MockClient mockClient2 = mockClientSupplier.get();
+ assertSame(mockClient1, mockClient2);
+ mockClientSupplier.close();
+ }
+
+ @Test
+ public void testSerialization() {
+ Supplier<MockClient> cf = (Supplier<MockClient> & Serializable) () -> new MockClient(true);
+ ExpiringCloseableSupplier<MockClient> mockClientSupplier = ExpiringCloseableSupplier.of(cf, 10);
+ MockClient mockClient = mockClientSupplier.get();
+
+ ExpiringCloseableSupplier<MockClient> mockClientSupplier2 =
+ SerializationUtils.roundtrip(mockClientSupplier);
+ MockClient mockClient2 = mockClientSupplier2.get();
+ assertFalse(mockClient2.isClosed());
+ assertNotSame(mockClient, mockClient2);
+ Uninterruptibles.sleepUninterruptibly(30, TimeUnit.MILLISECONDS);
+ assertTrue(mockClient.isClosed());
+ assertTrue(mockClient2.isClosed());
+ }
+
+ @Test
+ public void testMultipleSupplierShouldNotInterfere() {
+ Supplier<MockClient> cf = () -> new MockClient(true);
+ ExpiringCloseableSupplier<MockClient> mockClientSupplier = ExpiringCloseableSupplier.of(cf, 10);
+ ExpiringCloseableSupplier<MockClient> mockClientSupplier2 =
+ ExpiringCloseableSupplier.of(cf, 10);
+ MockClient mockClient = mockClientSupplier.get();
+ MockClient mockClient2 = mockClientSupplier2.get();
+ Uninterruptibles.sleepUninterruptibly(30, TimeUnit.MILLISECONDS);
+ assertTrue(mockClient.isClosed());
+ assertTrue(mockClient2.isClosed());
+ mockClientSupplier.close();
+ mockClientSupplier.close();
+ mockClientSupplier2.close();
+ mockClientSupplier2.close();
+ }
+
+ @Test
+ public void stressingTestManySuppliers() {
+ int num = 100000; // this should be sufficient for most production use cases
+ Supplier<MockClient> cf = () -> new MockClient(true);
+ List<MockClient> clients = Lists.newArrayList();
+ Random random = new Random(42);
+ for (int i = 0; i < num; i++) {
+ int delayCloseInterval = random.nextInt(1000) + 1;
+ ExpiringCloseableSupplier<MockClient> mockClientSupplier =
+ ExpiringCloseableSupplier.of(cf, delayCloseInterval);
+ MockClient mockClient = mockClientSupplier.get();
+ clients.add(mockClient);
+ }
+ Awaitility.waitAtMost(5, TimeUnit.SECONDS)
+ .until(() -> clients.stream().allMatch(MockClient::isClosed));
+ }
+
+ private static class MockClient implements StatefulCloseable, Serializable {
+ boolean withException;
+ AtomicBoolean closed = new AtomicBoolean(false);
+
+ MockClient(boolean withException) {
+ this.withException = withException;
+ }
+
+ @Override
+ public void close() throws IOException {
+ closed.set(true);
+ if (withException) {
+ throw new IOException("test exception!");
+ }
+ }
+
+ @Override
+ public boolean isClosed() {
+ return closed.get();
+ }
+ }
+}
diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/ShuffleServerManagerTestBase.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/ShuffleServerManagerTestBase.java
index 831fa0f2f..abe3a9dfa 100644
--- a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/ShuffleServerManagerTestBase.java
+++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/ShuffleServerManagerTestBase.java
@@ -23,6 +23,7 @@ import org.junit.jupiter.api.BeforeEach;
import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
import org.apache.uniffle.client.impl.grpc.ShuffleManagerGrpcClient;
import org.apache.uniffle.common.ClientType;
+import org.apache.uniffle.common.config.RssBaseConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.rpc.GrpcServer;
import org.apache.uniffle.shuffle.manager.DummyRssShuffleManager;
@@ -36,12 +37,17 @@ public class ShuffleServerManagerTestBase {
protected ShuffleManagerGrpcClient client;
protected static final String LOCALHOST = "localhost";
protected GrpcServer shuffleManagerServer;
+ protected RssConf rssConf;
protected RssShuffleManagerInterface getShuffleManager() {
return new DummyRssShuffleManager();
}
- protected RssConf getConf() {
+ protected ShuffleServerManagerTestBase() {
+ this.rssConf = getRssConf();
+ }
+
+ private RssConf getRssConf() {
RssConf conf = new RssConf();
// use a random port
conf.set(RPC_SERVER_PORT, 0);
@@ -49,7 +55,7 @@ public class ShuffleServerManagerTestBase {
}
protected GrpcServer createShuffleManagerServer() {
- return new ShuffleManagerServerFactory(getShuffleManager(), getConf()).getServer();
+ return new ShuffleManagerServerFactory(getShuffleManager(), rssConf).getServer();
}
@BeforeEach
@@ -57,7 +63,8 @@ public class ShuffleServerManagerTestBase {
shuffleManagerServer = createShuffleManagerServer();
shuffleManagerServer.start();
int port = shuffleManagerServer.getPort();
- client = factory.createShuffleManagerClient(ClientType.GRPC, LOCALHOST, port);
+ long rpcTimeout = rssConf.getLong(RssBaseConf.RSS_CLIENT_TYPE_GRPC_TIMEOUT_MS);
+ client = factory.createShuffleManagerClient(ClientType.GRPC, LOCALHOST, port, rpcTimeout);
}
@AfterEach
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java b/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
index c5b412a9e..6616fe7b1 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
@@ -17,8 +17,6 @@
package org.apache.uniffle.client.api;
-import java.io.Closeable;
-
import org.apache.uniffle.client.request.RssGetShuffleResultForMultiPartRequest;
import org.apache.uniffle.client.request.RssGetShuffleResultRequest;
import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest;
@@ -34,8 +32,9 @@ import org.apache.uniffle.client.response.RssReassignServersResponse;
import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
import org.apache.uniffle.client.response.RssReportShuffleResultResponse;
import org.apache.uniffle.client.response.RssReportShuffleWriteFailureResponse;
+import org.apache.uniffle.common.util.StatefulCloseable;
-public interface ShuffleManagerClient extends Closeable {
+public interface ShuffleManagerClient extends StatefulCloseable {
RssReportShuffleFetchFailureResponse reportShuffleFetchFailure(
RssReportShuffleFetchFailureRequest request);
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/factory/ShuffleManagerClientFactory.java b/internal-client/src/main/java/org/apache/uniffle/client/factory/ShuffleManagerClientFactory.java
index c55acdc22..66b4a2a9e 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/factory/ShuffleManagerClientFactory.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/factory/ShuffleManagerClientFactory.java
@@ -33,9 +33,9 @@ public class ShuffleManagerClientFactory {
private ShuffleManagerClientFactory() {}
public ShuffleManagerGrpcClient createShuffleManagerClient(
- ClientType clientType, String host, int port) {
+ ClientType clientType, String host, int port, long rpcTimeout) {
if (ClientType.GRPC.equals(clientType)) {
- return new ShuffleManagerGrpcClient(host, port);
+ return new ShuffleManagerGrpcClient(host, port, rpcTimeout);
} else {
throw new UnsupportedOperationException("Unsupported client type " + clientType);
}
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java
index 6dd9f4a1e..8cad876c2 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java
@@ -38,7 +38,6 @@ import org.apache.uniffle.client.response.RssReassignServersResponse;
import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
import org.apache.uniffle.client.response.RssReportShuffleResultResponse;
import org.apache.uniffle.client.response.RssReportShuffleWriteFailureResponse;
-import org.apache.uniffle.common.config.RssBaseConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.proto.RssProtos;
import org.apache.uniffle.proto.RssProtos.ReportShuffleFetchFailureRequest;
@@ -48,22 +47,22 @@ import org.apache.uniffle.proto.ShuffleManagerGrpc;
public class ShuffleManagerGrpcClient extends GrpcClient implements ShuffleManagerClient {
private static final Logger LOG = LoggerFactory.getLogger(ShuffleManagerGrpcClient.class);
- private static RssBaseConf rssConf = new RssBaseConf();
- private long rpcTimeout = rssConf.getLong(RssBaseConf.RSS_CLIENT_TYPE_GRPC_TIMEOUT_MS);
+ private final long rpcTimeout;
private ShuffleManagerGrpc.ShuffleManagerBlockingStub blockingStub;
- public ShuffleManagerGrpcClient(String host, int port) {
- this(host, port, 3);
+ public ShuffleManagerGrpcClient(String host, int port, long rpcTimeout) {
+ this(host, port, rpcTimeout, 3);
}
- public ShuffleManagerGrpcClient(String host, int port, int maxRetryAttempts) {
- this(host, port, maxRetryAttempts, true);
+ public ShuffleManagerGrpcClient(String host, int port, long rpcTimeout, int maxRetryAttempts) {
+ this(host, port, rpcTimeout, maxRetryAttempts, true);
}
public ShuffleManagerGrpcClient(
- String host, int port, int maxRetryAttempts, boolean usePlaintext) {
+ String host, int port, long rpcTimeout, int maxRetryAttempts, boolean usePlaintext) {
super(host, port, maxRetryAttempts, usePlaintext);
blockingStub = ShuffleManagerGrpc.newBlockingStub(channel);
+ this.rpcTimeout = rpcTimeout;
}
public ShuffleManagerGrpc.ShuffleManagerBlockingStub getBlockingStub() {
@@ -165,4 +164,9 @@ public class ShuffleManagerGrpcClient extends GrpcClient implements ShuffleManag
getBlockingStub().reportShuffleResult(request.toProto());
return RssReportShuffleResultResponse.fromProto(response);
}
+
+ @Override
+ public boolean isClosed() {
+ return channel.isShutdown() || channel.isTerminated();
+ }
}
diff --git a/internal-client/src/test/java/org/apache/uniffle/client/factory/ShuffleManagerClientFactoryTest.java b/internal-client/src/test/java/org/apache/uniffle/client/factory/ShuffleManagerClientFactoryTest.java
index 5fed54ff0..c40c06c32 100644
--- a/internal-client/src/test/java/org/apache/uniffle/client/factory/ShuffleManagerClientFactoryTest.java
+++ b/internal-client/src/test/java/org/apache/uniffle/client/factory/ShuffleManagerClientFactoryTest.java
@@ -32,10 +32,11 @@ class ShuffleManagerClientFactoryTest {
ShuffleManagerClientFactory factory = ShuffleManagerClientFactory.getInstance();
assertNotNull(factory);
// only grpc type is supported currently
- ShuffleManagerClient c = factory.createShuffleManagerClient(ClientType.GRPC, "localhost", 1234);
+ ShuffleManagerClient c =
+ factory.createShuffleManagerClient(ClientType.GRPC, "localhost", 1234, 60000);
assertNotNull(c);
assertThrows(
UnsupportedOperationException.class,
- () -> factory.createShuffleManagerClient(ClientType.GRPC_NETTY, "localhost", 1234));
+ () -> factory.createShuffleManagerClient(ClientType.GRPC_NETTY, "localhost", 1234, 60000));
}
}