This is an automated email from the ASF dual-hosted git repository.
xianjingfeng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 722d3079a [#1887] improvement: reject all requests from unregistered
apps in shuffle server (#1923)
722d3079a is described below
commit 722d3079aad4a8b0a0b2f3eb9b956658889e0aa2
Author: xianjingfeng <[email protected]>
AuthorDate: Tue Jul 23 15:28:02 2024 +0800
[#1887] improvement: reject all requests from unregistered apps in shuffle
server (#1923)
### What changes were proposed in this pull request?
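Reject requests from unregistered apps in the shuffle server. When a request carries a non-blank appId that the ShuffleTaskManager does not know about, or whose heartbeat has expired, the affected gRPC and Netty handlers now return StatusCode.NO_REGISTER immediately instead of processing the request.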
### Why are the changes needed?
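For better performance: requests from apps that are not (or are no longer) registered are rejected up front, so the server does no further work on them.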
Fix: #1887
### Does this PR introduce any user-facing change?
No.
### How was this patch tested?
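Existing unit tests. ShuffleServerGrpcTest was updated to expect "NO_REGISTER" in the error message for reportShuffleResult and getShuffleResult calls from an unregistered app.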
---
.../apache/uniffle/test/ShuffleServerGrpcTest.java | 4 +-
.../uniffle/server/ShuffleServerGrpcService.java | 164 ++++++++++++++++++---
.../apache/uniffle/server/ShuffleTaskManager.java | 8 +-
.../server/netty/ShuffleServerNettyHandler.java | 46 +++++-
4 files changed, 192 insertions(+), 30 deletions(-)
diff --git
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java
index 8a6e0cecf..df3e29971 100644
---
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java
+++
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java
@@ -259,7 +259,7 @@ public class ShuffleServerGrpcTest extends
IntegrationTestBase {
grpcShuffleServerClient.reportShuffleResult(request);
fail("Exception should be thrown");
} catch (Exception e) {
- assertTrue(e.getMessage().contains("error happened when report shuffle
result"));
+ assertTrue(e.getMessage().contains("NO_REGISTER"));
}
RssGetShuffleResultRequest req =
@@ -268,7 +268,7 @@ public class ShuffleServerGrpcTest extends
IntegrationTestBase {
grpcShuffleServerClient.getShuffleResult(req);
fail("Exception should be thrown");
} catch (Exception e) {
- assertTrue(e.getMessage().contains("Can't get shuffle result"));
+ assertTrue(e.getMessage().contains("NO_REGISTER"));
}
RssRegisterShuffleRequest rrsr =
diff --git
a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
index aea43d24e..06e2d8b92 100644
---
a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
+++
b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
@@ -105,19 +105,28 @@ public class ShuffleServerGrpcService extends
ShuffleServerImplBase {
RssProtos.ShuffleUnregisterByAppIdRequest request,
StreamObserver<RssProtos.ShuffleUnregisterByAppIdResponse>
responseStreamObserver) {
String appId = request.getAppId();
-
- StatusCode result = StatusCode.SUCCESS;
+ StatusCode status = verifyRequest(appId);
+ if (status != StatusCode.SUCCESS) {
+ RssProtos.ShuffleUnregisterByAppIdResponse reply =
+ RssProtos.ShuffleUnregisterByAppIdResponse.newBuilder()
+ .setStatus(status.toProto())
+ .setRetMsg(status.toString())
+ .build();
+ responseStreamObserver.onNext(reply);
+ responseStreamObserver.onCompleted();
+ return;
+ }
String responseMessage = "OK";
try {
shuffleServer.getShuffleTaskManager().removeShuffleDataAsync(appId);
} catch (Exception e) {
- result = StatusCode.INTERNAL_ERROR;
+ status = StatusCode.INTERNAL_ERROR;
}
RssProtos.ShuffleUnregisterByAppIdResponse reply =
RssProtos.ShuffleUnregisterByAppIdResponse.newBuilder()
- .setStatus(result.toProto())
+ .setStatus(status.toProto())
.setRetMsg(responseMessage)
.build();
responseStreamObserver.onNext(reply);
@@ -129,19 +138,29 @@ public class ShuffleServerGrpcService extends
ShuffleServerImplBase {
RssProtos.ShuffleUnregisterRequest request,
StreamObserver<RssProtos.ShuffleUnregisterResponse>
responseStreamObserver) {
String appId = request.getAppId();
+ StatusCode status = verifyRequest(appId);
+ if (status != StatusCode.SUCCESS) {
+ RssProtos.ShuffleUnregisterResponse reply =
+ RssProtos.ShuffleUnregisterResponse.newBuilder()
+ .setStatus(status.toProto())
+ .setRetMsg(status.toString())
+ .build();
+ responseStreamObserver.onNext(reply);
+ responseStreamObserver.onCompleted();
+ return;
+ }
int shuffleId = request.getShuffleId();
- StatusCode result = StatusCode.SUCCESS;
String responseMessage = "OK";
try {
shuffleServer.getShuffleTaskManager().removeShuffleDataAsync(appId,
shuffleId);
} catch (Exception e) {
- result = StatusCode.INTERNAL_ERROR;
+ status = StatusCode.INTERNAL_ERROR;
}
RssProtos.ShuffleUnregisterResponse reply =
RssProtos.ShuffleUnregisterResponse.newBuilder()
- .setStatus(result.toProto())
+ .setStatus(status.toProto())
.setRetMsg(responseMessage)
.build();
responseStreamObserver.onNext(reply);
@@ -430,12 +449,20 @@ public class ShuffleServerGrpcService extends
ShuffleServerImplBase {
@Override
public void commitShuffleTask(
ShuffleCommitRequest req, StreamObserver<ShuffleCommitResponse>
responseObserver) {
-
- ShuffleCommitResponse reply;
String appId = req.getAppId();
+ StatusCode status = verifyRequest(appId);
+ if (status != StatusCode.SUCCESS) {
+ ShuffleCommitResponse response =
+ ShuffleCommitResponse.newBuilder()
+ .setStatus(status.toProto())
+ .setRetMsg(status.toString())
+ .build();
+ responseObserver.onNext(response);
+ responseObserver.onCompleted();
+ return;
+ }
int shuffleId = req.getShuffleId();
- StatusCode status = StatusCode.SUCCESS;
String msg = "OK";
int commitCount = 0;
@@ -460,7 +487,7 @@ public class ShuffleServerGrpcService extends
ShuffleServerImplBase {
LOG.error(msg, e);
}
- reply =
+ ShuffleCommitResponse reply =
ShuffleCommitResponse.newBuilder()
.setCommitCount(commitCount)
.setStatus(status.toProto())
@@ -474,8 +501,18 @@ public class ShuffleServerGrpcService extends
ShuffleServerImplBase {
public void finishShuffle(
FinishShuffleRequest req, StreamObserver<FinishShuffleResponse>
responseObserver) {
String appId = req.getAppId();
+ StatusCode status = verifyRequest(appId);
+ if (status != StatusCode.SUCCESS) {
+ FinishShuffleResponse response =
+ FinishShuffleResponse.newBuilder()
+ .setStatus(status.toProto())
+ .setRetMsg(status.toString())
+ .build();
+ responseObserver.onNext(response);
+ responseObserver.onCompleted();
+ return;
+ }
int shuffleId = req.getShuffleId();
- StatusCode status;
String msg = "OK";
String errorMsg =
"Fail to finish shuffle for appId["
@@ -506,8 +543,18 @@ public class ShuffleServerGrpcService extends
ShuffleServerImplBase {
public void requireBuffer(
RequireBufferRequest request, StreamObserver<RequireBufferResponse>
responseObserver) {
String appId = request.getAppId();
+ StatusCode status = verifyRequest(appId);
+ if (status != StatusCode.SUCCESS) {
+ RequireBufferResponse response =
+ RequireBufferResponse.newBuilder()
+ .setStatus(status.toProto())
+ .setRetMsg(status.toString())
+ .build();
+ responseObserver.onNext(response);
+ responseObserver.onCompleted();
+ return;
+ }
long requireBufferId = -1;
- StatusCode status = StatusCode.SUCCESS;
try {
if (StringUtils.isEmpty(appId)) {
// To be compatible with older client version
@@ -548,6 +595,17 @@ public class ShuffleServerGrpcService extends
ShuffleServerImplBase {
public void appHeartbeat(
AppHeartBeatRequest request, StreamObserver<AppHeartBeatResponse>
responseObserver) {
String appId = request.getAppId();
+ StatusCode status = verifyRequest(appId);
+ if (status != StatusCode.SUCCESS) {
+ AppHeartBeatResponse response =
+ AppHeartBeatResponse.newBuilder()
+ .setStatus(status.toProto())
+ .setRetMsg(status.toString())
+ .build();
+ responseObserver.onNext(response);
+ responseObserver.onCompleted();
+ return;
+ }
shuffleServer.getShuffleTaskManager().refreshAppId(appId);
AppHeartBeatResponse response =
AppHeartBeatResponse.newBuilder()
@@ -572,12 +630,22 @@ public class ShuffleServerGrpcService extends
ShuffleServerImplBase {
ReportShuffleResultRequest request,
StreamObserver<ReportShuffleResultResponse> responseObserver) {
String appId = request.getAppId();
+ StatusCode status = verifyRequest(appId);
+ if (status != StatusCode.SUCCESS) {
+ ReportShuffleResultResponse response =
+ ReportShuffleResultResponse.newBuilder()
+ .setStatus(status.toProto())
+ .setRetMsg(status.toString())
+ .build();
+ responseObserver.onNext(response);
+ responseObserver.onCompleted();
+ return;
+ }
int shuffleId = request.getShuffleId();
long taskAttemptId = request.getTaskAttemptId();
int bitmapNum = request.getBitmapNum();
Map<Integer, long[]> partitionToBlockIds =
toPartitionBlocksMap(request.getPartitionToBlockIdsList());
- StatusCode status = StatusCode.SUCCESS;
String msg = "OK";
ReportShuffleResultResponse reply;
String requestInfo =
@@ -617,6 +685,17 @@ public class ShuffleServerGrpcService extends
ShuffleServerImplBase {
public void getShuffleResult(
GetShuffleResultRequest request,
StreamObserver<GetShuffleResultResponse> responseObserver) {
String appId = request.getAppId();
+ StatusCode status = verifyRequest(appId);
+ if (status != StatusCode.SUCCESS) {
+ GetShuffleResultResponse response =
+ GetShuffleResultResponse.newBuilder()
+ .setStatus(status.toProto())
+ .setRetMsg(status.toString())
+ .build();
+ responseObserver.onNext(response);
+ responseObserver.onCompleted();
+ return;
+ }
int shuffleId = request.getShuffleId();
int partitionId = request.getPartitionId();
BlockIdLayout blockIdLayout =
@@ -624,7 +703,6 @@ public class ShuffleServerGrpcService extends
ShuffleServerImplBase {
request.getBlockIdLayout().getSequenceNoBits(),
request.getBlockIdLayout().getPartitionIdBits(),
request.getBlockIdLayout().getTaskAttemptIdBits());
- StatusCode status = StatusCode.SUCCESS;
String msg = "OK";
GetShuffleResultResponse reply;
byte[] serializedBlockIds = null;
@@ -665,6 +743,17 @@ public class ShuffleServerGrpcService extends
ShuffleServerImplBase {
GetShuffleResultForMultiPartRequest request,
StreamObserver<GetShuffleResultForMultiPartResponse> responseObserver) {
String appId = request.getAppId();
+ StatusCode status = verifyRequest(appId);
+ if (status != StatusCode.SUCCESS) {
+ GetShuffleResultForMultiPartResponse response =
+ GetShuffleResultForMultiPartResponse.newBuilder()
+ .setStatus(status.toProto())
+ .setRetMsg(status.toString())
+ .build();
+ responseObserver.onNext(response);
+ responseObserver.onCompleted();
+ return;
+ }
int shuffleId = request.getShuffleId();
List<Integer> partitionsList = request.getPartitionsList();
BlockIdLayout blockIdLayout =
@@ -673,7 +762,6 @@ public class ShuffleServerGrpcService extends
ShuffleServerImplBase {
request.getBlockIdLayout().getPartitionIdBits(),
request.getBlockIdLayout().getTaskAttemptIdBits());
- StatusCode status = StatusCode.SUCCESS;
String msg = "OK";
GetShuffleResultForMultiPartResponse reply;
byte[] serializedBlockIds = null;
@@ -715,6 +803,17 @@ public class ShuffleServerGrpcService extends
ShuffleServerImplBase {
GetLocalShuffleDataRequest request,
StreamObserver<GetLocalShuffleDataResponse> responseObserver) {
String appId = request.getAppId();
+ StatusCode status = verifyRequest(appId);
+ if (status != StatusCode.SUCCESS) {
+ GetLocalShuffleDataResponse response =
+ GetLocalShuffleDataResponse.newBuilder()
+ .setStatus(status.toProto())
+ .setRetMsg(status.toString())
+ .build();
+ responseObserver.onNext(response);
+ responseObserver.onCompleted();
+ return;
+ }
int shuffleId = request.getShuffleId();
int partitionId = request.getPartitionId();
int partitionNumPerRange = request.getPartitionNumPerRange();
@@ -732,7 +831,6 @@ public class ShuffleServerGrpcService extends
ShuffleServerImplBase {
}
String storageType =
shuffleServer.getShuffleServerConf().get(RssBaseConf.RSS_STORAGE_TYPE).name();
- StatusCode status = StatusCode.SUCCESS;
String msg = "OK";
GetLocalShuffleDataResponse reply = null;
ShuffleDataResult sdr = null;
@@ -831,11 +929,21 @@ public class ShuffleServerGrpcService extends
ShuffleServerImplBase {
GetLocalShuffleIndexRequest request,
StreamObserver<GetLocalShuffleIndexResponse> responseObserver) {
String appId = request.getAppId();
+ StatusCode status = verifyRequest(appId);
+ if (status != StatusCode.SUCCESS) {
+ GetLocalShuffleIndexResponse reply =
+ GetLocalShuffleIndexResponse.newBuilder()
+ .setStatus(status.toProto())
+ .setRetMsg(status.toString())
+ .build();
+ responseObserver.onNext(reply);
+ responseObserver.onCompleted();
+ return;
+ }
int shuffleId = request.getShuffleId();
int partitionId = request.getPartitionId();
int partitionNumPerRange = request.getPartitionNumPerRange();
int partitionNum = request.getPartitionNum();
- StatusCode status = StatusCode.SUCCESS;
String msg = "OK";
GetLocalShuffleIndexResponse reply;
String requestInfo =
@@ -928,6 +1036,17 @@ public class ShuffleServerGrpcService extends
ShuffleServerImplBase {
GetMemoryShuffleDataRequest request,
StreamObserver<GetMemoryShuffleDataResponse> responseObserver) {
String appId = request.getAppId();
+ StatusCode status = verifyRequest(appId);
+ if (status != StatusCode.SUCCESS) {
+ GetMemoryShuffleDataResponse reply =
+ GetMemoryShuffleDataResponse.newBuilder()
+ .setStatus(status.toProto())
+ .setRetMsg(status.toString())
+ .build();
+ responseObserver.onNext(reply);
+ responseObserver.onCompleted();
+ return;
+ }
int shuffleId = request.getShuffleId();
int partitionId = request.getPartitionId();
long blockId = request.getLastBlockId();
@@ -943,7 +1062,6 @@ public class ShuffleServerGrpcService extends
ShuffleServerImplBase {
ShuffleServerGrpcMetrics.GET_MEMORY_SHUFFLE_DATA_METHOD,
transportTime);
}
}
- StatusCode status = StatusCode.SUCCESS;
String msg = "OK";
GetMemoryShuffleDataResponse reply;
String requestInfo =
@@ -1108,4 +1226,12 @@ public class ShuffleServerGrpcService extends
ShuffleServerImplBase {
}
return shuffleDataBlockSegments;
}
+
+ private StatusCode verifyRequest(String appId) {
+ if (StringUtils.isNotBlank(appId)
+ && shuffleServer.getShuffleTaskManager().isAppExpired(appId)) {
+ return StatusCode.NO_REGISTER;
+ }
+ return StatusCode.SUCCESS;
+ }
}
diff --git
a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
index 8fe597d03..b258c8a11 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
@@ -725,12 +725,12 @@ public class ShuffleTaskManager {
}
}
- private boolean isAppExpired(String appId) {
- if (shuffleTaskInfos.get(appId) == null) {
+ public boolean isAppExpired(String appId) {
+ ShuffleTaskInfo shuffleTaskInfo = shuffleTaskInfos.get(appId);
+ if (shuffleTaskInfo == null) {
return true;
}
- return System.currentTimeMillis() -
shuffleTaskInfos.get(appId).getCurrentTimes()
- > appExpiredWithoutHB;
+ return System.currentTimeMillis() - shuffleTaskInfo.getCurrentTimes() >
appExpiredWithoutHB;
}
/**
diff --git
a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
index 27a3f1dc4..cca6a3935 100644
---
a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
+++
b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
@@ -28,6 +28,7 @@ import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import org.apache.commons.collections.MapUtils;
import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -69,7 +70,6 @@ import org.apache.uniffle.storage.util.ShuffleStorageUtils;
public class ShuffleServerNettyHandler implements BaseMessageHandler {
private static final Logger LOG =
LoggerFactory.getLogger(ShuffleServerNettyHandler.class);
- private static final int RPC_TIMEOUT = 60000;
private final ShuffleServer shuffleServer;
public ShuffleServerNettyHandler(ShuffleServer shuffleServer) {
@@ -335,6 +335,18 @@ public class ShuffleServerNettyHandler implements
BaseMessageHandler {
public void handleGetMemoryShuffleDataRequest(
TransportClient client, GetMemoryShuffleDataRequest req) {
String appId = req.getAppId();
+ StatusCode status = verifyRequest(appId);
+ if (status != StatusCode.SUCCESS) {
+ GetMemoryShuffleDataResponse response =
+ new GetMemoryShuffleDataResponse(
+ req.getRequestId(),
+ status,
+ status.toString(),
+ Lists.newArrayList(),
+ Unpooled.EMPTY_BUFFER);
+ client.getChannel().writeAndFlush(response);
+ return;
+ }
int shuffleId = req.getShuffleId();
int partitionId = req.getPartitionId();
long blockId = req.getLastBlockId();
@@ -349,7 +361,6 @@ public class ShuffleServerNettyHandler implements
BaseMessageHandler {
.recordTransportTime(GetMemoryShuffleDataRequest.class.getName(),
transportTime);
}
}
- StatusCode status = StatusCode.SUCCESS;
String msg = "OK";
GetMemoryShuffleDataResponse response;
String requestInfo =
@@ -417,11 +428,18 @@ public class ShuffleServerNettyHandler implements
BaseMessageHandler {
public void handleGetLocalShuffleIndexRequest(
TransportClient client, GetLocalShuffleIndexRequest req) {
String appId = req.getAppId();
+ StatusCode status = verifyRequest(appId);
+ if (status != StatusCode.SUCCESS) {
+ GetLocalShuffleIndexResponse response =
+ new GetLocalShuffleIndexResponse(
+ req.getRequestId(), status, status.toString(),
Unpooled.EMPTY_BUFFER, 0L);
+ client.getChannel().writeAndFlush(response);
+ return;
+ }
int shuffleId = req.getShuffleId();
int partitionId = req.getPartitionId();
int partitionNumPerRange = req.getPartitionNumPerRange();
int partitionNum = req.getPartitionNum();
- StatusCode status = StatusCode.SUCCESS;
String msg = "OK";
GetLocalShuffleIndexResponse response;
String requestInfo =
@@ -501,7 +519,19 @@ public class ShuffleServerNettyHandler implements
BaseMessageHandler {
}
public void handleGetLocalShuffleData(TransportClient client,
GetLocalShuffleDataRequest req) {
+ GetLocalShuffleDataResponse response;
String appId = req.getAppId();
+ StatusCode status = verifyRequest(appId);
+ if (status != StatusCode.SUCCESS) {
+ response =
+ new GetLocalShuffleDataResponse(
+ req.getRequestId(),
+ status,
+ status.toString(),
+ new NettyManagedBuffer(Unpooled.EMPTY_BUFFER));
+ client.getChannel().writeAndFlush(response);
+ return;
+ }
int shuffleId = req.getShuffleId();
int partitionId = req.getPartitionId();
int partitionNumPerRange = req.getPartitionNumPerRange();
@@ -519,9 +549,7 @@ public class ShuffleServerNettyHandler implements
BaseMessageHandler {
}
String storageType =
shuffleServer.getShuffleServerConf().get(RssBaseConf.RSS_STORAGE_TYPE).name();
- StatusCode status = StatusCode.SUCCESS;
String msg = "OK";
- GetLocalShuffleDataResponse response;
String requestInfo =
"appId["
+ appId
@@ -625,6 +653,14 @@ public class ShuffleServerNettyHandler implements
BaseMessageHandler {
return ret;
}
+ private StatusCode verifyRequest(String appId) {
+ if (StringUtils.isNotBlank(appId)
+ && shuffleServer.getShuffleTaskManager().isAppExpired(appId)) {
+ return StatusCode.NO_REGISTER;
+ }
+ return StatusCode.SUCCESS;
+ }
+
class ReleaseMemoryAndRecordReadTimeListener implements
ChannelFutureListener {
private final long readStartedTime;
private final long readBufferSize;