This is an automated email from the ASF dual-hosted git repository.
szetszwo pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ratis.git
The following commit(s) were added to refs/heads/master by this push:
new 84285d3ea RATIS-1925. Support Zero-Copy in GrpcClientProtocolService
(#1007)
84285d3ea is described below
commit 84285d3eadca2f734157d2970fc2c5cce02331e1
Author: Duong Nguyen <[email protected]>
AuthorDate: Fri Jan 12 19:00:53 2024 -0800
RATIS-1925. Support Zero-Copy in GrpcClientProtocolService (#1007)
---
.../apache/ratis/protocol/RaftClientRequest.java | 10 +-
.../apache/ratis/util/ReferenceCountedObject.java | 21 ++++
.../main/java/org/apache/ratis/grpc/GrpcUtil.java | 26 ++++
.../apache/ratis/grpc/metrics/ZeroCopyMetrics.java | 58 +++++++++
.../grpc/server/GrpcClientProtocolService.java | 134 +++++++++++++--------
.../org/apache/ratis/grpc/server/GrpcService.java | 8 +-
.../ratis/grpc/util/ZeroCopyMessageMarshaller.java | 8 +-
.../apache/ratis/server/impl/LeaderStateImpl.java | 12 +-
.../ratis/server/impl/MessageStreamRequests.java | 45 +++++--
.../apache/ratis/server/impl/RaftServerImpl.java | 9 +-
.../impl/SimpleStateMachine4Testing.java | 2 +-
.../ratis/grpc/util/GrpcZeroCopyTestServer.java | 4 +-
12 files changed, 258 insertions(+), 79 deletions(-)
diff --git
a/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientRequest.java
b/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientRequest.java
index ed41f1ea2..18c157130 100644
---
a/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientRequest.java
+++
b/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientRequest.java
@@ -474,7 +474,13 @@ public class RaftClientRequest extends RaftClientMessage {
@Override
public String toString() {
- return super.toString() + ", seq=" +
ProtoUtils.toString(slidingWindowEntry) + ", "
- + type + ", " + getMessage();
+ return toStringShort() + ", " + getMessage();
+ }
+
+ /**
+ * @return a short string which does not include {@link #message}.
+ */
+ public String toStringShort() {
+ return super.toString() + ", seq=" +
ProtoUtils.toString(slidingWindowEntry) + ", " + type;
}
}
diff --git
a/ratis-common/src/main/java/org/apache/ratis/util/ReferenceCountedObject.java
b/ratis-common/src/main/java/org/apache/ratis/util/ReferenceCountedObject.java
index 3f72f5ffe..815b90dbc 100644
---
a/ratis-common/src/main/java/org/apache/ratis/util/ReferenceCountedObject.java
+++
b/ratis-common/src/main/java/org/apache/ratis/util/ReferenceCountedObject.java
@@ -19,6 +19,7 @@ package org.apache.ratis.util;
import org.apache.ratis.util.function.UncheckedAutoCloseableSupplier;
+import java.util.Collection;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
@@ -101,6 +102,26 @@ public interface ReferenceCountedObject<T> {
return wrap(value, () -> {}, ignored -> {});
}
+ static <T, V> ReferenceCountedObject<V>
delegateFrom(Collection<ReferenceCountedObject<T>> fromRefs, V value) {
+ return new ReferenceCountedObject<V>() {
+ @Override
+ public V get() {
+ return value;
+ }
+
+ @Override
+ public V retain() {
+ fromRefs.forEach(ReferenceCountedObject::retain);
+ return value;
+ }
+
+ @Override
+ public boolean release() {
+ return
fromRefs.stream().map(ReferenceCountedObject::release).reduce(true, Boolean::logicalAnd); // unlike allMatch, reduce does not short-circuit: every ref is released
+ }
+ };
+ }
+
/**
* @return a {@link ReferenceCountedObject} of the given value by delegating
to this object.
*/
diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcUtil.java
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcUtil.java
index 22653b6ef..0baefa2d3 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcUtil.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcUtil.java
@@ -24,8 +24,12 @@ import org.apache.ratis.security.TlsConf.TrustManagerConf;
import org.apache.ratis.security.TlsConf.CertificatesConf;
import org.apache.ratis.security.TlsConf.PrivateKeyConf;
import org.apache.ratis.security.TlsConf.KeyManagerConf;
+import org.apache.ratis.thirdparty.com.google.protobuf.MessageLite;
import org.apache.ratis.thirdparty.io.grpc.ManagedChannel;
import org.apache.ratis.thirdparty.io.grpc.Metadata;
+import org.apache.ratis.thirdparty.io.grpc.MethodDescriptor;
+import org.apache.ratis.thirdparty.io.grpc.ServerCallHandler;
+import org.apache.ratis.thirdparty.io.grpc.ServerServiceDefinition;
import org.apache.ratis.thirdparty.io.grpc.Status;
import org.apache.ratis.thirdparty.io.grpc.StatusRuntimeException;
import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver;
@@ -304,4 +308,26 @@ public interface GrpcUtil {
b.keyManager(privateKey.get(), certificates.get());
}
}
+
+ /**
+ * Adds a copy of {@code origMethod}, with its request marshaller replaced by {@code customMarshaller},
to {@code newServiceBuilder}; the server call handler is reused from {@code orig}.
+ *
+ * @param orig original service definition; it must contain a method with the same full name as {@code origMethod}.
+ * @param newServiceBuilder builder of the new service definition.
+ * @param origMethod the original method definition.
+ * @param customMarshaller custom marshaller to be set for the method.
+ * @param <Req> the request message type.
+ * @param <Resp> the response type.
+ */
+ static <Req extends MessageLite, Resp> void addMethodWithCustomMarshaller(
+ ServerServiceDefinition orig, ServerServiceDefinition.Builder
newServiceBuilder,
+ MethodDescriptor<Req, Resp> origMethod,
MethodDescriptor.PrototypeMarshaller<Req> customMarshaller) {
+ MethodDescriptor<Req, Resp> newMethod = origMethod.toBuilder()
+ .setRequestMarshaller(customMarshaller)
+ .build();
+ @SuppressWarnings("unchecked") // the handler from orig is not typed as <Req, Resp>, hence the cast below
+ ServerCallHandler<Req, Resp> serverCallHandler =
+ (ServerCallHandler<Req, Resp>)
orig.getMethod(newMethod.getFullMethodName()).getServerCallHandler();
+ newServiceBuilder.addMethod(newMethod, serverCallHandler);
+ }
}
diff --git
a/ratis-grpc/src/main/java/org/apache/ratis/grpc/metrics/ZeroCopyMetrics.java
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/metrics/ZeroCopyMetrics.java
new file mode 100644
index 000000000..20da4ee63
--- /dev/null
+++
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/metrics/ZeroCopyMetrics.java
@@ -0,0 +1,58 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.ratis.grpc.metrics;
+
+import org.apache.ratis.metrics.LongCounter;
+import org.apache.ratis.metrics.MetricRegistryInfo;
+import org.apache.ratis.metrics.RatisMetricRegistry;
+import org.apache.ratis.metrics.RatisMetrics;
+import org.apache.ratis.thirdparty.com.google.protobuf.AbstractMessage;
+
+/** Counts zero-copy messages, non-zero-copy messages and released messages; the message arguments are not used. */
+public class ZeroCopyMetrics extends RatisMetrics {
+ private static final String RATIS_GRPC_METRICS_APP_NAME = "ratis_grpc";
+ private static final String RATIS_GRPC_METRICS_COMP_NAME = "zero_copy";
+ private static final String RATIS_GRPC_METRICS_DESC = "Metrics for Ratis
Grpc Zero copy";
+
+ private final LongCounter zeroCopyMessages =
getRegistry().counter("num_zero_copy_messages");
+ private final LongCounter nonZeroCopyMessages =
getRegistry().counter("num_non_zero_copy_messages");
+ private final LongCounter releasedMessages =
getRegistry().counter("num_released_messages");
+
+ public ZeroCopyMetrics() {
+ super(createRegistry());
+ }
+
+ private static RatisMetricRegistry createRegistry() {
+ return create(new MetricRegistryInfo("",
+ RATIS_GRPC_METRICS_APP_NAME,
+ RATIS_GRPC_METRICS_COMP_NAME, RATIS_GRPC_METRICS_DESC));
+ }
+
+ public void onZeroCopyMessage(AbstractMessage ignored) {
+ zeroCopyMessages.inc();
+ }
+
+ public void onNonZeroCopyMessage(AbstractMessage ignored) {
+ nonZeroCopyMessages.inc();
+ }
+
+ public void onReleasedMessage(AbstractMessage ignored) {
+ releasedMessages.inc();
+ }
+
+}
diff --git
a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcClientProtocolService.java
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcClientProtocolService.java
index 9c1968467..e8de4def0 100644
---
a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcClientProtocolService.java
+++
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcClientProtocolService.java
@@ -19,10 +19,13 @@ package org.apache.ratis.grpc.server;
import org.apache.ratis.client.impl.ClientProtoUtils;
import org.apache.ratis.grpc.GrpcUtil;
+import org.apache.ratis.grpc.metrics.ZeroCopyMetrics;
+import org.apache.ratis.grpc.util.ZeroCopyMessageMarshaller;
import org.apache.ratis.protocol.*;
import org.apache.ratis.protocol.exceptions.AlreadyClosedException;
import org.apache.ratis.protocol.exceptions.GroupMismatchException;
import org.apache.ratis.protocol.exceptions.RaftException;
+import org.apache.ratis.thirdparty.io.grpc.ServerServiceDefinition;
import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver;
import org.apache.ratis.proto.RaftProtos.RaftClientReplyProto;
import org.apache.ratis.proto.RaftProtos.RaftClientRequestProto;
@@ -30,16 +33,15 @@ import
org.apache.ratis.proto.grpc.RaftClientProtocolServiceGrpc.RaftClientProto
import org.apache.ratis.util.CollectionUtils;
import org.apache.ratis.util.JavaUtils;
import org.apache.ratis.util.Preconditions;
+import org.apache.ratis.util.ReferenceCountedObject;
import org.apache.ratis.util.SlidingWindow;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
@@ -48,15 +50,21 @@ import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Supplier;
+import static org.apache.ratis.grpc.GrpcUtil.addMethodWithCustomMarshaller;
+import static
org.apache.ratis.proto.grpc.RaftClientProtocolServiceGrpc.getOrderedMethod;
+import static
org.apache.ratis.proto.grpc.RaftClientProtocolServiceGrpc.getUnorderedMethod;
+
class GrpcClientProtocolService extends RaftClientProtocolServiceImplBase {
private static final Logger LOG =
LoggerFactory.getLogger(GrpcClientProtocolService.class);
private static class PendingOrderedRequest implements
SlidingWindow.ServerSideRequest<RaftClientReply> {
+ private final ReferenceCountedObject<RaftClientRequest> requestRef;
private final RaftClientRequest request;
private final AtomicReference<RaftClientReply> reply = new
AtomicReference<>();
- PendingOrderedRequest(RaftClientRequest request) {
- this.request = request;
+ PendingOrderedRequest(ReferenceCountedObject<RaftClientRequest>
requestRef) {
+ this.requestRef = requestRef;
+ this.request = requestRef != null ? requestRef.get() : null;
}
@Override
@@ -76,15 +84,16 @@ class GrpcClientProtocolService extends
RaftClientProtocolServiceImplBase {
@Override
public void setReply(RaftClientReply r) {
final boolean set = reply.compareAndSet(null, r);
- Preconditions.assertTrue(set, () -> "Reply is already set: request=" +
request + ", reply=" + reply);
+ Preconditions.assertTrue(set, () -> "Reply is already set: request=" +
+ request.toStringShort() + ", reply=" + reply);
}
RaftClientReply getReply() {
return reply.get();
}
- RaftClientRequest getRequest() {
- return request;
+ ReferenceCountedObject<RaftClientRequest> getRequestRef() {
+ return requestRef;
}
@Override
@@ -135,18 +144,31 @@ class GrpcClientProtocolService extends
RaftClientProtocolServiceImplBase {
private final ExecutorService executor;
private final OrderedStreamObservers orderedStreamObservers = new
OrderedStreamObservers();
+ private final ZeroCopyMessageMarshaller<RaftClientRequestProto>
zeroCopyRequestMarshaller;
GrpcClientProtocolService(Supplier<RaftPeerId> idSupplier,
RaftClientAsynchronousProtocol protocol,
- ExecutorService executor) {
+ ExecutorService executor, ZeroCopyMetrics zeroCopyMetrics) {
this.idSupplier = idSupplier;
this.protocol = protocol;
this.executor = executor;
+ this.zeroCopyRequestMarshaller = new
ZeroCopyMessageMarshaller<>(RaftClientRequestProto.getDefaultInstance(),
+ zeroCopyMetrics::onZeroCopyMessage,
zeroCopyMetrics::onNonZeroCopyMessage, zeroCopyMetrics::onReleasedMessage);
}
RaftPeerId getId() {
return idSupplier.get();
}
+ ServerServiceDefinition bindServiceWithZeroCopy() { // like bindService(), but requests use zeroCopyRequestMarshaller
+ ServerServiceDefinition orig = super.bindService();
+ ServerServiceDefinition.Builder builder =
ServerServiceDefinition.builder(orig.getServiceDescriptor().getName());
+ // Copy only ordered and unordered, switching their request marshaller; any other method of orig is not copied.
+ addMethodWithCustomMarshaller(orig, builder, getOrderedMethod(),
zeroCopyRequestMarshaller);
+ addMethodWithCustomMarshaller(orig, builder, getUnorderedMethod(),
zeroCopyRequestMarshaller);
+
+ return builder.build();
+ }
+
@Override
public StreamObserver<RaftClientRequestProto>
ordered(StreamObserver<RaftClientReplyProto> responseObserver) {
final OrderedRequestStreamObserver so = new
OrderedRequestStreamObserver(responseObserver);
@@ -220,31 +242,38 @@ class GrpcClientProtocolService extends
RaftClientProtocolServiceImplBase {
return isClosed.get();
}
- CompletableFuture<Void> processClientRequest(RaftClientRequest request,
Consumer<RaftClientReply> replyHandler) {
- try {
- final String errMsg = LOG.isDebugEnabled() ? "processClientRequest for
" + request : "";
- return protocol.submitClientRequestAsync(request
- ).thenAcceptAsync(replyHandler, executor
- ).exceptionally(exception -> {
- // TODO: the exception may be from either raft or state machine.
- // Currently we skip all the following responses when getting an
- // exception from the state machine.
- responseError(exception, () -> errMsg);
- return null;
- });
- } catch (IOException e) {
- throw new CompletionException("Failed processClientRequest for " +
request + " in " + name, e);
- }
+ CompletableFuture<Void>
processClientRequest(ReferenceCountedObject<RaftClientRequest> requestRef,
+ Consumer<RaftClientReply> replyHandler) {
+ final String errMsg = LOG.isDebugEnabled() ? "processClientRequest for "
+ requestRef.get() : "";
+ return protocol.submitClientRequestAsync(requestRef
+ ).thenAcceptAsync(replyHandler, executor
+ ).exceptionally(exception -> {
+ // TODO: the exception may be from either raft or state machine.
+ // Currently we skip all the following responses when getting an
+ // exception from the state machine.
+ responseError(exception, () -> errMsg);
+ return null;
+ });
}
- abstract void processClientRequest(RaftClientRequest request);
+ abstract void
processClientRequest(ReferenceCountedObject<RaftClientRequest> requestRef);
@Override
public void onNext(RaftClientRequestProto request) {
+ ReferenceCountedObject<RaftClientRequest> requestRef = null;
try {
final RaftClientRequest r =
ClientProtoUtils.toRaftClientRequest(request);
- processClientRequest(r);
+ requestRef = ReferenceCountedObject.wrap(r, () -> {}, released -> {
+ if (released) {
+ zeroCopyRequestMarshaller.release(request);
+ }
+ });
+
+ processClientRequest(requestRef);
} catch (Exception e) {
+ if (requestRef == null) {
+ zeroCopyRequestMarshaller.release(request);
+ }
responseError(e, () -> "onNext for " +
ClientProtoUtils.toString(request) + " in " + name);
}
}
@@ -278,15 +307,18 @@ class GrpcClientProtocolService extends
RaftClientProtocolServiceImplBase {
}
@Override
- void processClientRequest(RaftClientRequest request) {
- final CompletableFuture<Void> f = processClientRequest(request, reply ->
{
+ void processClientRequest(ReferenceCountedObject<RaftClientRequest>
requestRef) {
+ final RaftClientRequest request = requestRef.retain();
+ final long callId = request.getCallId();
+
+ final CompletableFuture<Void> f = processClientRequest(requestRef, reply
-> {
if (!reply.isSuccess()) {
- LOG.info("Failed " + request + ", reply=" + reply);
+ LOG.info("Failed {}, reply={}", request, reply);
}
final RaftClientReplyProto proto =
ClientProtoUtils.toRaftClientReplyProto(reply);
responseNext(proto);
- });
- final long callId = request.getCallId();
+ }).whenComplete((r, e) -> requestRef.release());
+
put(callId, f);
f.thenAccept(dummy -> remove(callId));
}
@@ -329,31 +361,35 @@ class GrpcClientProtocolService extends
RaftClientProtocolServiceImplBase {
void processClientRequest(PendingOrderedRequest pending) {
final long seq = pending.getSeqNum();
- processClientRequest(pending.getRequest(),
+ processClientRequest(pending.getRequestRef(),
reply -> slidingWindow.receiveReply(seq, reply, this::sendReply));
}
@Override
- void processClientRequest(RaftClientRequest r) {
- if (isClosed()) {
- final AlreadyClosedException exception = new
AlreadyClosedException(getName() + ": the stream is closed");
- responseError(exception, () -> "processClientRequest (stream already
closed) for " + r);
- }
+ void processClientRequest(ReferenceCountedObject<RaftClientRequest>
requestRef) {
+ final RaftClientRequest request = requestRef.retain();
+ try {
+ if (isClosed()) {
+ final AlreadyClosedException exception = new
AlreadyClosedException(getName() + ": the stream is closed");
+ responseError(exception, () -> "processClientRequest (stream already
closed) for " + request);
+ }
- final RaftGroupId requestGroupId = r.getRaftGroupId();
- // use the group id in the first request as the group id of this observer
- final RaftGroupId updated = groupId.updateAndGet(g -> g != null ? g:
requestGroupId);
- final PendingOrderedRequest pending = new PendingOrderedRequest(r);
-
- if (!requestGroupId.equals(updated)) {
- final GroupMismatchException exception = new
GroupMismatchException(getId()
- + ": The group (" + requestGroupId + ") of " + r.getClientId()
- + " does not match the group (" + updated + ") of the " +
JavaUtils.getClassSimpleName(getClass()));
- responseError(exception, () -> "processClientRequest (Group
mismatched) for " + r);
- return;
+ final RaftGroupId requestGroupId = request.getRaftGroupId();
+ // use the group id in the first request as the group id of this
observer
+ final RaftGroupId updated = groupId.updateAndGet(g -> g != null ? g :
requestGroupId);
+ final PendingOrderedRequest pending = new
PendingOrderedRequest(requestRef);
+
+ if (!requestGroupId.equals(updated)) {
+ final GroupMismatchException exception = new
GroupMismatchException(getId()
+ + ": The group (" + requestGroupId + ") of " +
request.getClientId()
+ + " does not match the group (" + updated + ") of the " +
JavaUtils.getClassSimpleName(getClass()));
+ responseError(exception, () -> "processClientRequest (Group
mismatched) for " + request);
+ return;
+ }
+ slidingWindow.receivedRequest(pending, this::processClientRequest);
+ } finally {
+ requestRef.release();
}
-
- slidingWindow.receivedRequest(pending, this::processClientRequest);
}
private void sendReply(PendingOrderedRequest ready) {
diff --git
a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcService.java
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcService.java
index 097900a0f..d89afd565 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcService.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcService.java
@@ -21,6 +21,7 @@ import org.apache.ratis.conf.RaftProperties;
import org.apache.ratis.grpc.GrpcConfigKeys;
import org.apache.ratis.grpc.GrpcTlsConfig;
import org.apache.ratis.grpc.GrpcUtil;
+import org.apache.ratis.grpc.metrics.ZeroCopyMetrics;
import org.apache.ratis.grpc.metrics.intercept.server.MetricServerInterceptor;
import org.apache.ratis.protocol.RaftGroupId;
import org.apache.ratis.protocol.RaftPeerId;
@@ -153,6 +154,7 @@ public final class GrpcService extends
RaftServerRpcWithProxy<GrpcServerProtocol
private final GrpcClientProtocolService clientProtocolService;
private final MetricServerInterceptor serverInterceptor;
+ private final ZeroCopyMetrics zeroCopyMetrics;
public MetricServerInterceptor getServerInterceptor() {
return serverInterceptor;
@@ -199,7 +201,8 @@ public final class GrpcService extends
RaftServerRpcWithProxy<GrpcServerProtocol
GrpcConfigKeys.Server.asyncRequestThreadPoolCached(properties),
GrpcConfigKeys.Server.asyncRequestThreadPoolSize(properties),
getId() + "-request-");
- this.clientProtocolService = new GrpcClientProtocolService(idSupplier,
raftServer, executor);
+ this.zeroCopyMetrics = new ZeroCopyMetrics();
+ this.clientProtocolService = new GrpcClientProtocolService(idSupplier,
raftServer, executor, zeroCopyMetrics);
this.serverInterceptor = new MetricServerInterceptor(
idSupplier,
@@ -252,7 +255,8 @@ public final class GrpcService extends
RaftServerRpcWithProxy<GrpcServerProtocol
}
private void addClientService(NettyServerBuilder builder) {
- builder.addService(ServerInterceptors.intercept(clientProtocolService,
serverInterceptor));
+
builder.addService(ServerInterceptors.intercept(clientProtocolService.bindServiceWithZeroCopy(),
+ serverInterceptor));
}
private void addAdminService(RaftServer raftServer, NettyServerBuilder
nettyServerBuilder) {
diff --git
a/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyMessageMarshaller.java
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyMessageMarshaller.java
index bb8183a24..057550c13 100644
---
a/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyMessageMarshaller.java
+++
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyMessageMarshaller.java
@@ -62,12 +62,14 @@ public class ZeroCopyMessageMarshaller<T extends
MessageLite> implements Prototy
private final Consumer<T> zeroCopyCount;
private final Consumer<T> nonZeroCopyCount;
+ private final Consumer<T> releasedCount;
public ZeroCopyMessageMarshaller(T defaultInstance) {
- this(defaultInstance, m -> {}, m -> {});
+ this(defaultInstance, m -> {}, m -> {}, m -> {});
}
- public ZeroCopyMessageMarshaller(T defaultInstance, Consumer<T>
zeroCopyCount, Consumer<T> nonZeroCopyCount) {
+ public ZeroCopyMessageMarshaller(T defaultInstance, Consumer<T>
zeroCopyCount, Consumer<T> nonZeroCopyCount,
+ Consumer<T> releasedCount) {
this.name = JavaUtils.getClassSimpleName(defaultInstance.getClass()) +
"-Marshaller";
@SuppressWarnings("unchecked")
final Parser<T> p = (Parser<T>) defaultInstance.getParserForType();
@@ -76,6 +78,7 @@ public class ZeroCopyMessageMarshaller<T extends MessageLite>
implements Prototy
this.zeroCopyCount = zeroCopyCount;
this.nonZeroCopyCount = nonZeroCopyCount;
+ this.releasedCount = releasedCount;
}
@Override
@@ -124,6 +127,7 @@ public class ZeroCopyMessageMarshaller<T extends
MessageLite> implements Prototy
}
try {
stream.close();
+ releasedCount.accept(message);
} catch (IOException e) {
LOG.error(name + ": Failed to close stream.", e);
}
diff --git
a/ratis-server/src/main/java/org/apache/ratis/server/impl/LeaderStateImpl.java
b/ratis-server/src/main/java/org/apache/ratis/server/impl/LeaderStateImpl.java
index b2788918d..8864c220c 100644
---
a/ratis-server/src/main/java/org/apache/ratis/server/impl/LeaderStateImpl.java
+++
b/ratis-server/src/main/java/org/apache/ratis/server/impl/LeaderStateImpl.java
@@ -56,6 +56,7 @@ import org.apache.ratis.util.Daemon;
import org.apache.ratis.util.JavaUtils;
import org.apache.ratis.util.MemoizedSupplier;
import org.apache.ratis.util.Preconditions;
+import org.apache.ratis.util.ReferenceCountedObject;
import org.apache.ratis.util.TimeDuration;
import org.apache.ratis.util.Timestamp;
@@ -534,15 +535,16 @@ class LeaderStateImpl implements LeaderState {
return pendingRequests.add(permit, request, entry);
}
- CompletableFuture<RaftClientReply> streamAsync(RaftClientRequest request) {
- return messageStreamRequests.streamAsync(request)
+ CompletableFuture<RaftClientReply>
streamAsync(ReferenceCountedObject<RaftClientRequest> requestRef) {
+ RaftClientRequest request = requestRef.get();
+ return messageStreamRequests.streamAsync(requestRef)
.thenApply(dummy -> server.newSuccessReply(request))
.exceptionally(e -> exception2RaftClientReply(request, e));
}
- CompletableFuture<RaftClientRequest>
streamEndOfRequestAsync(RaftClientRequest request) {
- return messageStreamRequests.streamEndOfRequestAsync(request)
- .thenApply(bytes -> RaftClientRequest.toWriteRequest(request,
Message.valueOf(bytes)));
+ CompletableFuture<ReferenceCountedObject<RaftClientRequest>>
streamEndOfRequestAsync(
+ ReferenceCountedObject<RaftClientRequest> requestRef) {
+ return messageStreamRequests.streamEndOfRequestAsync(requestRef);
}
CompletableFuture<RaftClientReply> addWatchRequest(RaftClientRequest
request) {
diff --git
a/ratis-server/src/main/java/org/apache/ratis/server/impl/MessageStreamRequests.java
b/ratis-server/src/main/java/org/apache/ratis/server/impl/MessageStreamRequests.java
index ac81b348b..c00c57b36 100644
---
a/ratis-server/src/main/java/org/apache/ratis/server/impl/MessageStreamRequests.java
+++
b/ratis-server/src/main/java/org/apache/ratis/server/impl/MessageStreamRequests.java
@@ -25,12 +25,15 @@ import org.apache.ratis.protocol.exceptions.StreamException;
import org.apache.ratis.thirdparty.com.google.protobuf.ByteString;
import org.apache.ratis.util.JavaUtils;
import org.apache.ratis.util.Preconditions;
+import org.apache.ratis.util.ReferenceCountedObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.ConcurrentMap;
class MessageStreamRequests {
public static final Logger LOG =
LoggerFactory.getLogger(MessageStreamRequests.class);
@@ -39,12 +42,14 @@ class MessageStreamRequests {
private final ClientInvocationId key;
private long nextId = -1;
private ByteString bytes = ByteString.EMPTY;
+ private final List<ReferenceCountedObject<RaftClientRequest>> pendingRefs
= new LinkedList<>();
PendingStream(ClientInvocationId key) {
this.key = key;
}
- synchronized CompletableFuture<ByteString> append(long messageId, Message
message) {
+ synchronized CompletableFuture<ByteString> append(long messageId,
+ ReferenceCountedObject<RaftClientRequest> requestRef) {
if (nextId == -1) {
nextId = messageId;
} else if (messageId != nextId) {
@@ -52,27 +57,38 @@ class MessageStreamRequests {
"Unexpected message id in " + key + ": messageId = " + messageId +
" != nextId = " + nextId));
}
nextId++;
+ final Message message = requestRef.retain().getMessage();
+ pendingRefs.add(requestRef);
bytes = bytes.concat(message.getContent());
return CompletableFuture.completedFuture(bytes);
}
- synchronized CompletableFuture<ByteString> getBytes(long messageId,
Message message) {
- return append(messageId, message);
+ synchronized CompletableFuture<ReferenceCountedObject<RaftClientRequest>>
getWriteRequest(long messageId,
+ ReferenceCountedObject<RaftClientRequest> requestRef) {
+ return append(messageId, requestRef)
+ .thenApply(appended ->
RaftClientRequest.toWriteRequest(requestRef.get(), () -> appended))
+ .thenApply(request ->
ReferenceCountedObject.delegateFrom(pendingRefs, request));
+ }
+
+ synchronized void clear() {
+ pendingRefs.forEach(ReferenceCountedObject::release);
+ pendingRefs.clear();
}
}
static class StreamMap {
- private final ConcurrentMap<ClientInvocationId, PendingStream> map = new
ConcurrentHashMap<>();
+ private final Map<ClientInvocationId, PendingStream> map = new HashMap<>();
- PendingStream computeIfAbsent(ClientInvocationId key) {
+ synchronized PendingStream computeIfAbsent(ClientInvocationId key) {
return map.computeIfAbsent(key, PendingStream::new);
}
- PendingStream remove(ClientInvocationId key) {
+ synchronized PendingStream remove(ClientInvocationId key) {
return map.remove(key);
}
- void clear() {
+ synchronized void clear() {
+ map.values().forEach(PendingStream::clear);
map.clear();
}
}
@@ -84,15 +100,18 @@ class MessageStreamRequests {
this.name = name + "-" + JavaUtils.getClassSimpleName(getClass());
}
- CompletableFuture<?> streamAsync(RaftClientRequest request) {
+ CompletableFuture<?> streamAsync(ReferenceCountedObject<RaftClientRequest>
requestRef) {
+ final RaftClientRequest request = requestRef.get();
final MessageStreamRequestTypeProto stream =
request.getType().getMessageStream();
Preconditions.assertTrue(!stream.getEndOfRequest());
final ClientInvocationId key =
ClientInvocationId.valueOf(request.getClientId(), stream.getStreamId());
final PendingStream pending = streams.computeIfAbsent(key);
- return pending.append(stream.getMessageId(), request.getMessage());
+ return pending.append(stream.getMessageId(), requestRef);
}
- CompletableFuture<ByteString> streamEndOfRequestAsync(RaftClientRequest
request) {
+ CompletableFuture<ReferenceCountedObject<RaftClientRequest>>
streamEndOfRequestAsync(
+ ReferenceCountedObject<RaftClientRequest> requestRef) {
+ final RaftClientRequest request = requestRef.get();
final MessageStreamRequestTypeProto stream =
request.getType().getMessageStream();
Preconditions.assertTrue(stream.getEndOfRequest());
final ClientInvocationId key =
ClientInvocationId.valueOf(request.getClientId(), stream.getStreamId());
@@ -101,7 +120,7 @@ class MessageStreamRequests {
if (pending == null) {
return JavaUtils.completeExceptionally(new StreamException(name + ": " +
key + " not found"));
}
- return pending.getBytes(stream.getMessageId(), request.getMessage());
+ return pending.getWriteRequest(stream.getMessageId(), requestRef);
}
void clear() {
diff --git
a/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java
b/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java
index 51067a87a..64fa52029 100644
---
a/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java
+++
b/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java
@@ -1102,21 +1102,22 @@ class RaftServerImpl implements RaftServer.Division,
}
if (request.getType().getMessageStream().getEndOfRequest()) {
- final CompletableFuture<RaftClientRequest> f =
streamEndOfRequestAsync(request);
+ final CompletableFuture<ReferenceCountedObject<RaftClientRequest>> f =
streamEndOfRequestAsync(requestRef);
if (f.isCompletedExceptionally()) {
return f.thenApply(r -> null);
}
// the message stream has ended and the request become a WRITE request
- return replyFuture(requestRef.delegate(f.join()));
+ return replyFuture(f.join());
}
return role.getLeaderState()
- .map(ls -> ls.streamAsync(request))
+ .map(ls -> ls.streamAsync(requestRef))
.orElseGet(() -> CompletableFuture.completedFuture(
newExceptionReply(request, generateNotLeaderException())));
}
- private CompletableFuture<RaftClientRequest>
streamEndOfRequestAsync(RaftClientRequest request) {
+ private CompletableFuture<ReferenceCountedObject<RaftClientRequest>>
streamEndOfRequestAsync(
+ ReferenceCountedObject<RaftClientRequest> request) {
return role.getLeaderState()
.map(ls -> ls.streamEndOfRequestAsync(request))
.orElse(null);
diff --git
a/ratis-server/src/test/java/org/apache/ratis/statemachine/impl/SimpleStateMachine4Testing.java
b/ratis-server/src/test/java/org/apache/ratis/statemachine/impl/SimpleStateMachine4Testing.java
index 312c9508d..07073be52 100644
---
a/ratis-server/src/test/java/org/apache/ratis/statemachine/impl/SimpleStateMachine4Testing.java
+++
b/ratis-server/src/test/java/org/apache/ratis/statemachine/impl/SimpleStateMachine4Testing.java
@@ -328,7 +328,7 @@ public class SimpleStateMachine4Testing extends
BaseStateMachine {
final String string = request.getContent().toStringUtf8();
Exception exception;
try {
- LOG.info("query " + string);
+ LOG.info("query {}, all available: {}", string, dataMap.keySet());
final LogEntryProto entry = dataMap.get(string);
if (entry != null) {
return
CompletableFuture.completedFuture(Message.valueOf(entry.toByteString()));
diff --git
a/ratis-test/src/test/java/org/apache/ratis/grpc/util/GrpcZeroCopyTestServer.java
b/ratis-test/src/test/java/org/apache/ratis/grpc/util/GrpcZeroCopyTestServer.java
index e1bfe4e22..21db98d4c 100644
---
a/ratis-test/src/test/java/org/apache/ratis/grpc/util/GrpcZeroCopyTestServer.java
+++
b/ratis-test/src/test/java/org/apache/ratis/grpc/util/GrpcZeroCopyTestServer.java
@@ -75,12 +75,14 @@ class GrpcZeroCopyTestServer implements Closeable {
private final Count zeroCopyCount = new Count();
private final Count nonZeroCopyCount = new Count();
+ private final Count releasedCount = new Count();
private final Server server;
private final ZeroCopyMessageMarshaller<BinaryRequest> marshaller = new
ZeroCopyMessageMarshaller<>(
BinaryRequest.getDefaultInstance(),
zeroCopyCount::inc,
- nonZeroCopyCount::inc);
+ nonZeroCopyCount::inc,
+ releasedCount::inc);
GrpcZeroCopyTestServer(int port) {
final GreeterImpl greeter = new GreeterImpl();