This is an automated email from the ASF dual-hosted git repository.

szetszwo pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ratis.git


The following commit(s) were added to refs/heads/master by this push:
     new 84285d3ea RATIS-1925. Support Zero-Copy in GrpcClientProtocolService (#1007)
84285d3ea is described below

commit 84285d3eadca2f734157d2970fc2c5cce02331e1
Author: Duong Nguyen <[email protected]>
AuthorDate: Fri Jan 12 19:00:53 2024 -0800

    RATIS-1925. Support Zero-Copy in GrpcClientProtocolService (#1007)
---
 .../apache/ratis/protocol/RaftClientRequest.java   |  10 +-
 .../apache/ratis/util/ReferenceCountedObject.java  |  21 ++++
 .../main/java/org/apache/ratis/grpc/GrpcUtil.java  |  26 ++++
 .../apache/ratis/grpc/metrics/ZeroCopyMetrics.java |  58 +++++++++
 .../grpc/server/GrpcClientProtocolService.java     | 134 +++++++++++++--------
 .../org/apache/ratis/grpc/server/GrpcService.java  |   8 +-
 .../ratis/grpc/util/ZeroCopyMessageMarshaller.java |   8 +-
 .../apache/ratis/server/impl/LeaderStateImpl.java  |  12 +-
 .../ratis/server/impl/MessageStreamRequests.java   |  45 +++++--
 .../apache/ratis/server/impl/RaftServerImpl.java   |   9 +-
 .../impl/SimpleStateMachine4Testing.java           |   2 +-
 .../ratis/grpc/util/GrpcZeroCopyTestServer.java    |   4 +-
 12 files changed, 258 insertions(+), 79 deletions(-)

diff --git 
a/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientRequest.java 
b/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientRequest.java
index ed41f1ea2..18c157130 100644
--- 
a/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientRequest.java
+++ 
b/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientRequest.java
@@ -474,7 +474,13 @@ public class RaftClientRequest extends RaftClientMessage {
 
   @Override
   public String toString() {
-    return super.toString() + ", seq=" + 
ProtoUtils.toString(slidingWindowEntry) + ", "
-        + type + ", " + getMessage();
+    return toStringShort() + ", " + getMessage();
+  }
+
+  /**
+   * @return a short string which does not include {@link #message}.
+   */
+  public String toStringShort() {
+    return super.toString() + ", seq=" + 
ProtoUtils.toString(slidingWindowEntry) + ", " + type;
   }
 }
diff --git 
a/ratis-common/src/main/java/org/apache/ratis/util/ReferenceCountedObject.java 
b/ratis-common/src/main/java/org/apache/ratis/util/ReferenceCountedObject.java
index 3f72f5ffe..815b90dbc 100644
--- 
a/ratis-common/src/main/java/org/apache/ratis/util/ReferenceCountedObject.java
+++ 
b/ratis-common/src/main/java/org/apache/ratis/util/ReferenceCountedObject.java
@@ -19,6 +19,7 @@ package org.apache.ratis.util;
 
 import org.apache.ratis.util.function.UncheckedAutoCloseableSupplier;
 
+import java.util.Collection;
 import java.util.Objects;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
@@ -101,6 +102,26 @@ public interface ReferenceCountedObject<T> {
     return wrap(value, () -> {}, ignored -> {});
   }
 
+  static <T, V> ReferenceCountedObject<V> 
delegateFrom(Collection<ReferenceCountedObject<T>> fromRefs, V value) {
+    return new ReferenceCountedObject<V>() {
+      @Override
+      public V get() {
+        return value;
+      }
+
+      @Override
+      public V retain() {
+        fromRefs.forEach(ReferenceCountedObject::retain);
+        return value;
+      }
+
+      @Override
+      public boolean release() {
+        return 
fromRefs.stream().map(ReferenceCountedObject::release).allMatch(r -> r);
+      }
+    };
+  }
+
   /**
    * @return a {@link ReferenceCountedObject} of the given value by delegating to this object.
    */
diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcUtil.java 
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcUtil.java
index 22653b6ef..0baefa2d3 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcUtil.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcUtil.java
@@ -24,8 +24,12 @@ import org.apache.ratis.security.TlsConf.TrustManagerConf;
 import org.apache.ratis.security.TlsConf.CertificatesConf;
 import org.apache.ratis.security.TlsConf.PrivateKeyConf;
 import org.apache.ratis.security.TlsConf.KeyManagerConf;
+import org.apache.ratis.thirdparty.com.google.protobuf.MessageLite;
 import org.apache.ratis.thirdparty.io.grpc.ManagedChannel;
 import org.apache.ratis.thirdparty.io.grpc.Metadata;
+import org.apache.ratis.thirdparty.io.grpc.MethodDescriptor;
+import org.apache.ratis.thirdparty.io.grpc.ServerCallHandler;
+import org.apache.ratis.thirdparty.io.grpc.ServerServiceDefinition;
 import org.apache.ratis.thirdparty.io.grpc.Status;
 import org.apache.ratis.thirdparty.io.grpc.StatusRuntimeException;
 import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver;
@@ -304,4 +308,26 @@ public interface GrpcUtil {
       b.keyManager(privateKey.get(), certificates.get());
     }
   }
+
+  /**
+   * Used to add a method to Service definition with a custom request marshaller.
+   *
+   * @param orig original service definition.
+   * @param newServiceBuilder builder of the new service definition.
+   * @param origMethod the original method definition.
+   * @param customMarshaller custom marshaller to be set for the method.
+   * @param <Req>
+   * @param <Resp>
+   */
+  static <Req extends MessageLite, Resp> void addMethodWithCustomMarshaller(
+      ServerServiceDefinition orig, ServerServiceDefinition.Builder 
newServiceBuilder,
+      MethodDescriptor<Req, Resp> origMethod, 
MethodDescriptor.PrototypeMarshaller<Req> customMarshaller) {
+    MethodDescriptor<Req, Resp> newMethod = origMethod.toBuilder()
+        .setRequestMarshaller(customMarshaller)
+        .build();
+    @SuppressWarnings("unchecked")
+    ServerCallHandler<Req, Resp> serverCallHandler =
+        (ServerCallHandler<Req, Resp>) 
orig.getMethod(newMethod.getFullMethodName()).getServerCallHandler();
+    newServiceBuilder.addMethod(newMethod, serverCallHandler);
+  }
 }
diff --git 
a/ratis-grpc/src/main/java/org/apache/ratis/grpc/metrics/ZeroCopyMetrics.java 
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/metrics/ZeroCopyMetrics.java
new file mode 100644
index 000000000..20da4ee63
--- /dev/null
+++ 
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/metrics/ZeroCopyMetrics.java
@@ -0,0 +1,58 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.ratis.grpc.metrics;
+
+import org.apache.ratis.metrics.LongCounter;
+import org.apache.ratis.metrics.MetricRegistryInfo;
+import org.apache.ratis.metrics.RatisMetricRegistry;
+import org.apache.ratis.metrics.RatisMetrics;
+import org.apache.ratis.thirdparty.com.google.protobuf.AbstractMessage;
+
+public class ZeroCopyMetrics extends RatisMetrics {
+  private static final String RATIS_GRPC_METRICS_APP_NAME = "ratis_grpc";
+  private static final String RATIS_GRPC_METRICS_COMP_NAME = "zero_copy";
+  private static final String RATIS_GRPC_METRICS_DESC = "Metrics for Ratis Grpc Zero copy";
+
+  private final LongCounter zeroCopyMessages = 
getRegistry().counter("num_zero_copy_messages");
+  private final LongCounter nonZeroCopyMessages = 
getRegistry().counter("num_non_zero_copy_messages");
+  private final LongCounter releasedMessages = 
getRegistry().counter("num_released_messages");
+
+  public ZeroCopyMetrics() {
+    super(createRegistry());
+  }
+
+  private static RatisMetricRegistry createRegistry() {
+    return create(new MetricRegistryInfo("",
+        RATIS_GRPC_METRICS_APP_NAME,
+        RATIS_GRPC_METRICS_COMP_NAME, RATIS_GRPC_METRICS_DESC));
+  }
+
+
+  public void onZeroCopyMessage(AbstractMessage ignored) {
+    zeroCopyMessages.inc();
+  }
+
+  public void onNonZeroCopyMessage(AbstractMessage ignored) {
+    nonZeroCopyMessages.inc();
+  }
+
+  public void onReleasedMessage(AbstractMessage ignored) {
+    releasedMessages.inc();
+  }
+
+}
\ No newline at end of file
diff --git 
a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcClientProtocolService.java
 
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcClientProtocolService.java
index 9c1968467..e8de4def0 100644
--- 
a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcClientProtocolService.java
+++ 
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcClientProtocolService.java
@@ -19,10 +19,13 @@ package org.apache.ratis.grpc.server;
 
 import org.apache.ratis.client.impl.ClientProtoUtils;
 import org.apache.ratis.grpc.GrpcUtil;
+import org.apache.ratis.grpc.metrics.ZeroCopyMetrics;
+import org.apache.ratis.grpc.util.ZeroCopyMessageMarshaller;
 import org.apache.ratis.protocol.*;
 import org.apache.ratis.protocol.exceptions.AlreadyClosedException;
 import org.apache.ratis.protocol.exceptions.GroupMismatchException;
 import org.apache.ratis.protocol.exceptions.RaftException;
+import org.apache.ratis.thirdparty.io.grpc.ServerServiceDefinition;
 import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver;
 import org.apache.ratis.proto.RaftProtos.RaftClientReplyProto;
 import org.apache.ratis.proto.RaftProtos.RaftClientRequestProto;
@@ -30,16 +33,15 @@ import 
org.apache.ratis.proto.grpc.RaftClientProtocolServiceGrpc.RaftClientProto
 import org.apache.ratis.util.CollectionUtils;
 import org.apache.ratis.util.JavaUtils;
 import org.apache.ratis.util.Preconditions;
+import org.apache.ratis.util.ReferenceCountedObject;
 import org.apache.ratis.util.SlidingWindow;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.io.IOException;
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.Map;
 import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.CompletionException;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.atomic.AtomicBoolean;
@@ -48,15 +50,21 @@ import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
 
+import static org.apache.ratis.grpc.GrpcUtil.addMethodWithCustomMarshaller;
+import static 
org.apache.ratis.proto.grpc.RaftClientProtocolServiceGrpc.getOrderedMethod;
+import static 
org.apache.ratis.proto.grpc.RaftClientProtocolServiceGrpc.getUnorderedMethod;
+
 class GrpcClientProtocolService extends RaftClientProtocolServiceImplBase {
   private static final Logger LOG = 
LoggerFactory.getLogger(GrpcClientProtocolService.class);
 
   private static class PendingOrderedRequest implements 
SlidingWindow.ServerSideRequest<RaftClientReply> {
+    private final ReferenceCountedObject<RaftClientRequest> requestRef;
     private final RaftClientRequest request;
     private final AtomicReference<RaftClientReply> reply = new 
AtomicReference<>();
 
-    PendingOrderedRequest(RaftClientRequest request) {
-      this.request = request;
+    PendingOrderedRequest(ReferenceCountedObject<RaftClientRequest> 
requestRef) {
+      this.requestRef = requestRef;
+      this.request = requestRef != null ? requestRef.get() : null;
     }
 
     @Override
@@ -76,15 +84,16 @@ class GrpcClientProtocolService extends 
RaftClientProtocolServiceImplBase {
     @Override
     public void setReply(RaftClientReply r) {
       final boolean set = reply.compareAndSet(null, r);
-      Preconditions.assertTrue(set, () -> "Reply is already set: request=" + 
request + ", reply=" + reply);
+      Preconditions.assertTrue(set, () -> "Reply is already set: request=" +
+          request.toStringShort() + ", reply=" + reply);
     }
 
     RaftClientReply getReply() {
       return reply.get();
     }
 
-    RaftClientRequest getRequest() {
-      return request;
+    ReferenceCountedObject<RaftClientRequest> getRequestRef() {
+      return requestRef;
     }
 
     @Override
@@ -135,18 +144,31 @@ class GrpcClientProtocolService extends 
RaftClientProtocolServiceImplBase {
   private final ExecutorService executor;
 
   private final OrderedStreamObservers orderedStreamObservers = new 
OrderedStreamObservers();
+  private final ZeroCopyMessageMarshaller<RaftClientRequestProto> 
zeroCopyRequestMarshaller;
 
   GrpcClientProtocolService(Supplier<RaftPeerId> idSupplier, 
RaftClientAsynchronousProtocol protocol,
-      ExecutorService executor) {
+      ExecutorService executor, ZeroCopyMetrics zeroCopyMetrics) {
     this.idSupplier = idSupplier;
     this.protocol = protocol;
     this.executor = executor;
+    this.zeroCopyRequestMarshaller = new 
ZeroCopyMessageMarshaller<>(RaftClientRequestProto.getDefaultInstance(),
+        zeroCopyMetrics::onZeroCopyMessage, 
zeroCopyMetrics::onNonZeroCopyMessage, zeroCopyMetrics::onReleasedMessage);
   }
 
   RaftPeerId getId() {
     return idSupplier.get();
   }
 
+  ServerServiceDefinition bindServiceWithZeroCopy() {
+    ServerServiceDefinition orig = super.bindService();
+    ServerServiceDefinition.Builder builder = 
ServerServiceDefinition.builder(orig.getServiceDescriptor().getName());
+
+    addMethodWithCustomMarshaller(orig, builder, getOrderedMethod(), 
zeroCopyRequestMarshaller);
+    addMethodWithCustomMarshaller(orig, builder, getUnorderedMethod(), 
zeroCopyRequestMarshaller);
+
+    return builder.build();
+  }
+
   @Override
   public StreamObserver<RaftClientRequestProto> 
ordered(StreamObserver<RaftClientReplyProto> responseObserver) {
     final OrderedRequestStreamObserver so = new 
OrderedRequestStreamObserver(responseObserver);
@@ -220,31 +242,38 @@ class GrpcClientProtocolService extends 
RaftClientProtocolServiceImplBase {
       return isClosed.get();
     }
 
-    CompletableFuture<Void> processClientRequest(RaftClientRequest request, 
Consumer<RaftClientReply> replyHandler) {
-      try {
-        final String errMsg = LOG.isDebugEnabled() ? "processClientRequest for " + request : "";
-        return protocol.submitClientRequestAsync(request
-        ).thenAcceptAsync(replyHandler, executor
-        ).exceptionally(exception -> {
-          // TODO: the exception may be from either raft or state machine.
-          // Currently we skip all the following responses when getting an
-          // exception from the state machine.
-          responseError(exception, () -> errMsg);
-          return null;
-        });
-      } catch (IOException e) {
-        throw new CompletionException("Failed processClientRequest for " + 
request + " in " + name, e);
-      }
+    CompletableFuture<Void> 
processClientRequest(ReferenceCountedObject<RaftClientRequest> requestRef,
+        Consumer<RaftClientReply> replyHandler) {
+      final String errMsg = LOG.isDebugEnabled() ? "processClientRequest for " + requestRef.get() : "";
+      return protocol.submitClientRequestAsync(requestRef
+      ).thenAcceptAsync(replyHandler, executor
+      ).exceptionally(exception -> {
+        // TODO: the exception may be from either raft or state machine.
+        // Currently we skip all the following responses when getting an
+        // exception from the state machine.
+        responseError(exception, () -> errMsg);
+        return null;
+      });
     }
 
-    abstract void processClientRequest(RaftClientRequest request);
+    abstract void 
processClientRequest(ReferenceCountedObject<RaftClientRequest> requestRef);
 
     @Override
     public void onNext(RaftClientRequestProto request) {
+      ReferenceCountedObject<RaftClientRequest> requestRef = null;
       try {
         final RaftClientRequest r = 
ClientProtoUtils.toRaftClientRequest(request);
-        processClientRequest(r);
+        requestRef = ReferenceCountedObject.wrap(r, () -> {}, released -> {
+          if (released) {
+            zeroCopyRequestMarshaller.release(request);
+          }
+        });
+
+        processClientRequest(requestRef);
       } catch (Exception e) {
+        if (requestRef == null) {
+          zeroCopyRequestMarshaller.release(request);
+        }
         responseError(e, () -> "onNext for " + 
ClientProtoUtils.toString(request) + " in " + name);
       }
     }
@@ -278,15 +307,18 @@ class GrpcClientProtocolService extends 
RaftClientProtocolServiceImplBase {
     }
 
     @Override
-    void processClientRequest(RaftClientRequest request) {
-      final CompletableFuture<Void> f = processClientRequest(request, reply -> 
{
+    void processClientRequest(ReferenceCountedObject<RaftClientRequest> 
requestRef) {
+      final RaftClientRequest request = requestRef.retain();
+      final long callId = request.getCallId();
+
+      final CompletableFuture<Void> f = processClientRequest(requestRef, reply -> {
         if (!reply.isSuccess()) {
-          LOG.info("Failed " + request + ", reply=" + reply);
+          LOG.info("Failed {}, reply={}", request, reply);
         }
         final RaftClientReplyProto proto = 
ClientProtoUtils.toRaftClientReplyProto(reply);
         responseNext(proto);
-      });
-      final long callId = request.getCallId();
+      }).whenComplete((r, e) -> requestRef.release());
+
       put(callId, f);
       f.thenAccept(dummy -> remove(callId));
     }
@@ -329,31 +361,35 @@ class GrpcClientProtocolService extends 
RaftClientProtocolServiceImplBase {
 
     void processClientRequest(PendingOrderedRequest pending) {
       final long seq = pending.getSeqNum();
-      processClientRequest(pending.getRequest(),
+      processClientRequest(pending.getRequestRef(),
           reply -> slidingWindow.receiveReply(seq, reply, this::sendReply));
     }
 
     @Override
-    void processClientRequest(RaftClientRequest r) {
-      if (isClosed()) {
-        final AlreadyClosedException exception = new 
AlreadyClosedException(getName() + ": the stream is closed");
-        responseError(exception, () -> "processClientRequest (stream already closed) for " + r);
-      }
+    void processClientRequest(ReferenceCountedObject<RaftClientRequest> 
requestRef) {
+      final RaftClientRequest request = requestRef.retain();
+      try {
+        if (isClosed()) {
+          final AlreadyClosedException exception = new 
AlreadyClosedException(getName() + ": the stream is closed");
+          responseError(exception, () -> "processClientRequest (stream already closed) for " + request);
+        }
 
-      final RaftGroupId requestGroupId = r.getRaftGroupId();
-      // use the group id in the first request as the group id of this observer
-      final RaftGroupId updated = groupId.updateAndGet(g -> g != null ? g: 
requestGroupId);
-      final PendingOrderedRequest pending = new PendingOrderedRequest(r);
-
-      if (!requestGroupId.equals(updated)) {
-        final GroupMismatchException exception = new 
GroupMismatchException(getId()
-            + ": The group (" + requestGroupId + ") of " + r.getClientId()
-            + " does not match the group (" + updated + ") of the " + 
JavaUtils.getClassSimpleName(getClass()));
-        responseError(exception, () -> "processClientRequest (Group mismatched) for " + r);
-        return;
+        final RaftGroupId requestGroupId = request.getRaftGroupId();
+        // use the group id in the first request as the group id of this observer
+        final RaftGroupId updated = groupId.updateAndGet(g -> g != null ? g : 
requestGroupId);
+        final PendingOrderedRequest pending = new 
PendingOrderedRequest(requestRef);
+
+        if (!requestGroupId.equals(updated)) {
+          final GroupMismatchException exception = new 
GroupMismatchException(getId()
+              + ": The group (" + requestGroupId + ") of " + 
request.getClientId()
+              + " does not match the group (" + updated + ") of the " + 
JavaUtils.getClassSimpleName(getClass()));
+          responseError(exception, () -> "processClientRequest (Group mismatched) for " + request);
+          return;
+        }
+        slidingWindow.receivedRequest(pending, this::processClientRequest);
+      } finally {
+        requestRef.release();
       }
-
-      slidingWindow.receivedRequest(pending, this::processClientRequest);
     }
 
     private void sendReply(PendingOrderedRequest ready) {
diff --git 
a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcService.java 
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcService.java
index 097900a0f..d89afd565 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcService.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcService.java
@@ -21,6 +21,7 @@ import org.apache.ratis.conf.RaftProperties;
 import org.apache.ratis.grpc.GrpcConfigKeys;
 import org.apache.ratis.grpc.GrpcTlsConfig;
 import org.apache.ratis.grpc.GrpcUtil;
+import org.apache.ratis.grpc.metrics.ZeroCopyMetrics;
 import org.apache.ratis.grpc.metrics.intercept.server.MetricServerInterceptor;
 import org.apache.ratis.protocol.RaftGroupId;
 import org.apache.ratis.protocol.RaftPeerId;
@@ -153,6 +154,7 @@ public final class GrpcService extends 
RaftServerRpcWithProxy<GrpcServerProtocol
   private final GrpcClientProtocolService clientProtocolService;
 
   private final MetricServerInterceptor serverInterceptor;
+  private final ZeroCopyMetrics zeroCopyMetrics;
 
   public MetricServerInterceptor getServerInterceptor() {
     return serverInterceptor;
@@ -199,7 +201,8 @@ public final class GrpcService extends 
RaftServerRpcWithProxy<GrpcServerProtocol
         GrpcConfigKeys.Server.asyncRequestThreadPoolCached(properties),
         GrpcConfigKeys.Server.asyncRequestThreadPoolSize(properties),
         getId() + "-request-");
-    this.clientProtocolService = new GrpcClientProtocolService(idSupplier, 
raftServer, executor);
+    this.zeroCopyMetrics = new ZeroCopyMetrics();
+    this.clientProtocolService = new GrpcClientProtocolService(idSupplier, 
raftServer, executor, zeroCopyMetrics);
 
     this.serverInterceptor = new MetricServerInterceptor(
         idSupplier,
@@ -252,7 +255,8 @@ public final class GrpcService extends 
RaftServerRpcWithProxy<GrpcServerProtocol
   }
 
   private void addClientService(NettyServerBuilder builder) {
-    builder.addService(ServerInterceptors.intercept(clientProtocolService, 
serverInterceptor));
+    
builder.addService(ServerInterceptors.intercept(clientProtocolService.bindServiceWithZeroCopy(),
+        serverInterceptor));
   }
 
   private void addAdminService(RaftServer raftServer, NettyServerBuilder 
nettyServerBuilder) {
diff --git 
a/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyMessageMarshaller.java
 
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyMessageMarshaller.java
index bb8183a24..057550c13 100644
--- 
a/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyMessageMarshaller.java
+++ 
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyMessageMarshaller.java
@@ -62,12 +62,14 @@ public class ZeroCopyMessageMarshaller<T extends 
MessageLite> implements Prototy
 
   private final Consumer<T> zeroCopyCount;
   private final Consumer<T> nonZeroCopyCount;
+  private final Consumer<T> releasedCount;
 
   public ZeroCopyMessageMarshaller(T defaultInstance) {
-    this(defaultInstance, m -> {}, m -> {});
+    this(defaultInstance, m -> {}, m -> {}, m -> {});
   }
 
-  public ZeroCopyMessageMarshaller(T defaultInstance, Consumer<T> 
zeroCopyCount, Consumer<T> nonZeroCopyCount) {
+  public ZeroCopyMessageMarshaller(T defaultInstance, Consumer<T> 
zeroCopyCount, Consumer<T> nonZeroCopyCount,
+      Consumer<T> releasedCount) {
     this.name = JavaUtils.getClassSimpleName(defaultInstance.getClass()) + 
"-Marshaller";
     @SuppressWarnings("unchecked")
     final Parser<T> p = (Parser<T>) defaultInstance.getParserForType();
@@ -76,6 +78,7 @@ public class ZeroCopyMessageMarshaller<T extends MessageLite> 
implements Prototy
 
     this.zeroCopyCount = zeroCopyCount;
     this.nonZeroCopyCount = nonZeroCopyCount;
+    this.releasedCount = releasedCount;
   }
 
   @Override
@@ -124,6 +127,7 @@ public class ZeroCopyMessageMarshaller<T extends 
MessageLite> implements Prototy
     }
     try {
       stream.close();
+      releasedCount.accept(message);
     } catch (IOException e) {
       LOG.error(name + ": Failed to close stream.", e);
     }
diff --git 
a/ratis-server/src/main/java/org/apache/ratis/server/impl/LeaderStateImpl.java 
b/ratis-server/src/main/java/org/apache/ratis/server/impl/LeaderStateImpl.java
index b2788918d..8864c220c 100644
--- 
a/ratis-server/src/main/java/org/apache/ratis/server/impl/LeaderStateImpl.java
+++ 
b/ratis-server/src/main/java/org/apache/ratis/server/impl/LeaderStateImpl.java
@@ -56,6 +56,7 @@ import org.apache.ratis.util.Daemon;
 import org.apache.ratis.util.JavaUtils;
 import org.apache.ratis.util.MemoizedSupplier;
 import org.apache.ratis.util.Preconditions;
+import org.apache.ratis.util.ReferenceCountedObject;
 import org.apache.ratis.util.TimeDuration;
 import org.apache.ratis.util.Timestamp;
 
@@ -534,15 +535,16 @@ class LeaderStateImpl implements LeaderState {
     return pendingRequests.add(permit, request, entry);
   }
 
-  CompletableFuture<RaftClientReply> streamAsync(RaftClientRequest request) {
-    return messageStreamRequests.streamAsync(request)
+  CompletableFuture<RaftClientReply> 
streamAsync(ReferenceCountedObject<RaftClientRequest> requestRef) {
+    RaftClientRequest request = requestRef.get();
+    return messageStreamRequests.streamAsync(requestRef)
         .thenApply(dummy -> server.newSuccessReply(request))
         .exceptionally(e -> exception2RaftClientReply(request, e));
   }
 
-  CompletableFuture<RaftClientRequest> 
streamEndOfRequestAsync(RaftClientRequest request) {
-    return messageStreamRequests.streamEndOfRequestAsync(request)
-        .thenApply(bytes -> RaftClientRequest.toWriteRequest(request, 
Message.valueOf(bytes)));
+  CompletableFuture<ReferenceCountedObject<RaftClientRequest>> 
streamEndOfRequestAsync(
+      ReferenceCountedObject<RaftClientRequest> requestRef) {
+    return messageStreamRequests.streamEndOfRequestAsync(requestRef);
   }
 
   CompletableFuture<RaftClientReply> addWatchRequest(RaftClientRequest 
request) {
diff --git 
a/ratis-server/src/main/java/org/apache/ratis/server/impl/MessageStreamRequests.java
 
b/ratis-server/src/main/java/org/apache/ratis/server/impl/MessageStreamRequests.java
index ac81b348b..c00c57b36 100644
--- 
a/ratis-server/src/main/java/org/apache/ratis/server/impl/MessageStreamRequests.java
+++ 
b/ratis-server/src/main/java/org/apache/ratis/server/impl/MessageStreamRequests.java
@@ -25,12 +25,15 @@ import org.apache.ratis.protocol.exceptions.StreamException;
 import org.apache.ratis.thirdparty.com.google.protobuf.ByteString;
 import org.apache.ratis.util.JavaUtils;
 import org.apache.ratis.util.Preconditions;
+import org.apache.ratis.util.ReferenceCountedObject;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
 import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.ConcurrentMap;
 
 class MessageStreamRequests {
   public static final Logger LOG = 
LoggerFactory.getLogger(MessageStreamRequests.class);
@@ -39,12 +42,14 @@ class MessageStreamRequests {
     private final ClientInvocationId key;
     private long nextId = -1;
     private ByteString bytes = ByteString.EMPTY;
+    private final List<ReferenceCountedObject<RaftClientRequest>> pendingRefs 
= new LinkedList<>();
 
     PendingStream(ClientInvocationId key) {
       this.key = key;
     }
 
-    synchronized CompletableFuture<ByteString> append(long messageId, Message 
message) {
+    synchronized CompletableFuture<ByteString> append(long messageId,
+        ReferenceCountedObject<RaftClientRequest> requestRef) {
       if (nextId == -1) {
         nextId = messageId;
       } else if (messageId != nextId) {
@@ -52,27 +57,38 @@ class MessageStreamRequests {
             "Unexpected message id in " + key + ": messageId = " + messageId + 
" != nextId = " + nextId));
       }
       nextId++;
+      final Message message = requestRef.retain().getMessage();
+      pendingRefs.add(requestRef);
       bytes = bytes.concat(message.getContent());
       return CompletableFuture.completedFuture(bytes);
     }
 
-    synchronized CompletableFuture<ByteString> getBytes(long messageId, 
Message message) {
-      return append(messageId, message);
+    synchronized CompletableFuture<ReferenceCountedObject<RaftClientRequest>> 
getWriteRequest(long messageId,
+        ReferenceCountedObject<RaftClientRequest> requestRef) {
+      return append(messageId, requestRef)
+          .thenApply(appended -> 
RaftClientRequest.toWriteRequest(requestRef.get(), () -> appended))
+          .thenApply(request -> 
ReferenceCountedObject.delegateFrom(pendingRefs, request));
+    }
+
+    synchronized void clear() {
+      pendingRefs.forEach(ReferenceCountedObject::release);
+      pendingRefs.clear();
     }
   }
 
   static class StreamMap {
-    private final ConcurrentMap<ClientInvocationId, PendingStream> map = new 
ConcurrentHashMap<>();
+    private final Map<ClientInvocationId, PendingStream> map = new HashMap<>();
 
-    PendingStream computeIfAbsent(ClientInvocationId key) {
+    synchronized PendingStream computeIfAbsent(ClientInvocationId key) {
       return map.computeIfAbsent(key, PendingStream::new);
     }
 
-    PendingStream remove(ClientInvocationId key) {
+    synchronized PendingStream remove(ClientInvocationId key) {
       return map.remove(key);
     }
 
-    void clear() {
+    synchronized void clear() {
+      map.values().forEach(PendingStream::clear);
       map.clear();
     }
   }
@@ -84,15 +100,18 @@ class MessageStreamRequests {
     this.name = name + "-" + JavaUtils.getClassSimpleName(getClass());
   }
 
-  CompletableFuture<?> streamAsync(RaftClientRequest request) {
+  CompletableFuture<?> streamAsync(ReferenceCountedObject<RaftClientRequest> 
requestRef) {
+    final RaftClientRequest request = requestRef.get();
     final MessageStreamRequestTypeProto stream = 
request.getType().getMessageStream();
     Preconditions.assertTrue(!stream.getEndOfRequest());
     final ClientInvocationId key = 
ClientInvocationId.valueOf(request.getClientId(), stream.getStreamId());
     final PendingStream pending = streams.computeIfAbsent(key);
-    return pending.append(stream.getMessageId(), request.getMessage());
+    return pending.append(stream.getMessageId(), requestRef);
   }
 
-  CompletableFuture<ByteString> streamEndOfRequestAsync(RaftClientRequest 
request) {
+  CompletableFuture<ReferenceCountedObject<RaftClientRequest>> 
streamEndOfRequestAsync(
+      ReferenceCountedObject<RaftClientRequest> requestRef) {
+    final RaftClientRequest request = requestRef.get();
     final MessageStreamRequestTypeProto stream = 
request.getType().getMessageStream();
     Preconditions.assertTrue(stream.getEndOfRequest());
     final ClientInvocationId key = 
ClientInvocationId.valueOf(request.getClientId(), stream.getStreamId());
@@ -101,7 +120,7 @@ class MessageStreamRequests {
     if (pending == null) {
       return JavaUtils.completeExceptionally(new StreamException(name + ": " + 
key + " not found"));
     }
-    return pending.getBytes(stream.getMessageId(), request.getMessage());
+    return pending.getWriteRequest(stream.getMessageId(), requestRef);
   }
 
   void clear() {
diff --git 
a/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java 
b/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java
index 51067a87a..64fa52029 100644
--- 
a/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java
+++ 
b/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java
@@ -1102,21 +1102,22 @@ class RaftServerImpl implements RaftServer.Division,
     }
 
     if (request.getType().getMessageStream().getEndOfRequest()) {
-      final CompletableFuture<RaftClientRequest> f = 
streamEndOfRequestAsync(request);
+      final CompletableFuture<ReferenceCountedObject<RaftClientRequest>> f = 
streamEndOfRequestAsync(requestRef);
       if (f.isCompletedExceptionally()) {
         return f.thenApply(r -> null);
       }
       // the message stream has ended and the request become a WRITE request
-      return replyFuture(requestRef.delegate(f.join()));
+      return replyFuture(f.join());
     }
 
     return role.getLeaderState()
-        .map(ls -> ls.streamAsync(request))
+        .map(ls -> ls.streamAsync(requestRef))
         .orElseGet(() -> CompletableFuture.completedFuture(
             newExceptionReply(request, generateNotLeaderException())));
   }
 
-  private CompletableFuture<RaftClientRequest> 
streamEndOfRequestAsync(RaftClientRequest request) {
+  private CompletableFuture<ReferenceCountedObject<RaftClientRequest>> 
streamEndOfRequestAsync(
+      ReferenceCountedObject<RaftClientRequest> request) {
     return role.getLeaderState()
         .map(ls -> ls.streamEndOfRequestAsync(request))
         .orElse(null);
diff --git 
a/ratis-server/src/test/java/org/apache/ratis/statemachine/impl/SimpleStateMachine4Testing.java
 
b/ratis-server/src/test/java/org/apache/ratis/statemachine/impl/SimpleStateMachine4Testing.java
index 312c9508d..07073be52 100644
--- 
a/ratis-server/src/test/java/org/apache/ratis/statemachine/impl/SimpleStateMachine4Testing.java
+++ 
b/ratis-server/src/test/java/org/apache/ratis/statemachine/impl/SimpleStateMachine4Testing.java
@@ -328,7 +328,7 @@ public class SimpleStateMachine4Testing extends 
BaseStateMachine {
     final String string = request.getContent().toStringUtf8();
     Exception exception;
     try {
-      LOG.info("query " + string);
+      LOG.info("query {}, all available: {}", string, dataMap.keySet());
       final LogEntryProto entry = dataMap.get(string);
       if (entry != null) {
         return 
CompletableFuture.completedFuture(Message.valueOf(entry.toByteString()));
diff --git 
a/ratis-test/src/test/java/org/apache/ratis/grpc/util/GrpcZeroCopyTestServer.java
 
b/ratis-test/src/test/java/org/apache/ratis/grpc/util/GrpcZeroCopyTestServer.java
index e1bfe4e22..21db98d4c 100644
--- 
a/ratis-test/src/test/java/org/apache/ratis/grpc/util/GrpcZeroCopyTestServer.java
+++ 
b/ratis-test/src/test/java/org/apache/ratis/grpc/util/GrpcZeroCopyTestServer.java
@@ -75,12 +75,14 @@ class GrpcZeroCopyTestServer implements Closeable {
 
   private final Count zeroCopyCount = new Count();
   private final Count nonZeroCopyCount = new Count();
+  private final Count releasedCount = new Count();
 
   private final Server server;
   private final ZeroCopyMessageMarshaller<BinaryRequest> marshaller = new 
ZeroCopyMessageMarshaller<>(
       BinaryRequest.getDefaultInstance(),
       zeroCopyCount::inc,
-      nonZeroCopyCount::inc);
+      nonZeroCopyCount::inc,
+      releasedCount::inc);
 
   GrpcZeroCopyTestServer(int port) {
     final GreeterImpl greeter = new GreeterImpl();


Reply via email to