This is an automated email from the ASF dual-hosted git repository.

szetszwo pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ratis.git


The following commit(s) were added to refs/heads/master by this push:
     new 06927826b RATIS-1939. Add TestGrpcZeroCopy. (#971)
06927826b is described below

commit 06927826b15e1c91628b6897b8638ef5b96330e2
Author: Tsz-Wo Nicholas Sze <[email protected]>
AuthorDate: Tue Nov 28 17:49:16 2023 -0800

    RATIS-1939. Add TestGrpcZeroCopy. (#971)
---
 .../apache/ratis/util/TraditionalBinaryPrefix.java |   5 +
 .../ratis/grpc/util/ZeroCopyMessageMarshaller.java | 182 +++++++++++++------
 .../ratis/grpc/util/ZeroCopyReadinessChecker.java  |  74 --------
 ratis-proto/src/main/proto/Test.proto              |  12 ++
 .../ratis/grpc/util/GrpcZeroCopyTestClient.java    | 160 +++++++++++++++++
 .../ratis/grpc/util/GrpcZeroCopyTestServer.java    | 198 +++++++++++++++++++++
 .../apache/ratis/grpc/util/TestGrpcZeroCopy.java   | 173 ++++++++++++++++++
 7 files changed, 676 insertions(+), 128 deletions(-)

diff --git 
a/ratis-common/src/main/java/org/apache/ratis/util/TraditionalBinaryPrefix.java 
b/ratis-common/src/main/java/org/apache/ratis/util/TraditionalBinaryPrefix.java
index 32697685e..b1902cfad 100644
--- 
a/ratis-common/src/main/java/org/apache/ratis/util/TraditionalBinaryPrefix.java
+++ 
b/ratis-common/src/main/java/org/apache/ratis/util/TraditionalBinaryPrefix.java
@@ -161,4 +161,9 @@ public enum TraditionalBinaryPrefix {
       return b.append(' ').append(prefix.symbol).append(unit).toString();
     }
   }
+
+  /** The same as long2String(n, "", 3). */
+  public static String long2String(long n) {
+    return long2String(n, "", 3);
+  }
 }
diff --git 
a/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyMessageMarshaller.java
 
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyMessageMarshaller.java
index a47920fc5..f415fb006 100644
--- 
a/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyMessageMarshaller.java
+++ 
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyMessageMarshaller.java
@@ -29,6 +29,10 @@ import org.apache.ratis.thirdparty.io.grpc.KnownLength;
 import 
org.apache.ratis.thirdparty.io.grpc.MethodDescriptor.PrototypeMarshaller;
 import org.apache.ratis.thirdparty.io.grpc.Status;
 import org.apache.ratis.thirdparty.io.grpc.protobuf.lite.ProtoLiteUtils;
+import org.apache.ratis.util.JavaUtils;
+import org.apache.ratis.util.Preconditions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
 import java.io.InputStream;
@@ -38,6 +42,8 @@ import java.util.IdentityHashMap;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
+import java.util.function.Consumer;
 
 /**
  * Custom gRPC marshaller to use zero memory copy feature of gRPC when 
deserializing messages. This
@@ -47,14 +53,29 @@ import java.util.Map;
  * close it when it's no longer needed. Otherwise, it'd cause memory leak.
  */
 public class ZeroCopyMessageMarshaller<T extends MessageLite> implements 
PrototypeMarshaller<T> {
-  private Map<T, InputStream> unclosedStreams =
-      Collections.synchronizedMap(new IdentityHashMap<>());
+  static final Logger LOG = 
LoggerFactory.getLogger(ZeroCopyMessageMarshaller.class);
+
+  private final String name;
+  private final Map<T, InputStream> unclosedStreams = 
Collections.synchronizedMap(new IdentityHashMap<>());
   private final Parser<T> parser;
   private final PrototypeMarshaller<T> marshaller;
 
+  private final Consumer<T> zeroCopyCount;
+  private final Consumer<T> nonZeroCopyCount;
+
   public ZeroCopyMessageMarshaller(T defaultInstance) {
-    parser = (Parser<T>) defaultInstance.getParserForType();
-    marshaller = (PrototypeMarshaller<T>) 
ProtoLiteUtils.marshaller(defaultInstance);
+    this(defaultInstance, m -> {}, m -> {});
+  }
+
+  public ZeroCopyMessageMarshaller(T defaultInstance, Consumer<T> 
zeroCopyCount, Consumer<T> nonZeroCopyCount) {
+    this.name = JavaUtils.getClassSimpleName(defaultInstance.getClass()) + 
"-Marshaller";
+    @SuppressWarnings("unchecked")
+    final Parser<T> p = (Parser<T>) defaultInstance.getParserForType();
+    this.parser = p;
+    this.marshaller = (PrototypeMarshaller<T>) 
ProtoLiteUtils.marshaller(defaultInstance);
+
+    this.zeroCopyCount = zeroCopyCount;
+    this.nonZeroCopyCount = nonZeroCopyCount;
   }
 
   @Override
@@ -74,55 +95,116 @@ public class ZeroCopyMessageMarshaller<T extends 
MessageLite> implements Prototy
 
   @Override
   public T parse(InputStream stream) {
+    final T message;
     try {
-      if (stream instanceof KnownLength
-          && stream instanceof Detachable
-          && stream instanceof HasByteBuffer
-          && ((HasByteBuffer) stream).byteBufferSupported()) {
-        int size = stream.available();
-        // Stream is now detached here and should be closed later.
-        InputStream detachedStream = ((Detachable) stream).detach();
-        try {
-          // This mark call is to keep buffer while traversing buffers using 
skip.
-          detachedStream.mark(size);
-          List<ByteString> byteStrings = new LinkedList<>();
-          while (detachedStream.available() != 0) {
-            ByteBuffer buffer = ((HasByteBuffer) 
detachedStream).getByteBuffer();
-            byteStrings.add(UnsafeByteOperations.unsafeWrap(buffer));
-            detachedStream.skip(buffer.remaining());
-          }
-          detachedStream.reset();
-          CodedInputStream codedInputStream = 
ByteString.copyFrom(byteStrings).newCodedInput();
-          codedInputStream.enableAliasing(true);
-          codedInputStream.setSizeLimit(Integer.MAX_VALUE);
-          // fast path (no memory copy)
-          T message;
-          try {
-            message = parseFrom(codedInputStream);
-          } catch (InvalidProtocolBufferException ipbe) {
-            throw Status.INTERNAL
-                .withDescription("Invalid protobuf byte sequence")
-                .withCause(ipbe)
-                .asRuntimeException();
-          }
-          unclosedStreams.put(message, detachedStream);
-          detachedStream = null;
-          return message;
-        } finally {
-          if (detachedStream != null) {
-            detachedStream.close();
-          }
-        }
-      }
+      // fast path (no memory copy)
+      message = parseZeroCopy(stream);
     } catch (IOException e) {
-      throw new RuntimeException(e);
+      throw Status.INTERNAL
+          .withDescription("Failed to parseZeroCopy")
+          .withCause(e)
+          .asRuntimeException();
     }
+    if (message != null) {
+      zeroCopyCount.accept(message);
+      return message;
+    }
+
     // slow path
-    return marshaller.parse(stream);
+    final T copied = marshaller.parse(stream);
+    nonZeroCopyCount.accept(copied);
+    return copied;
+  }
+
+  /** Release the underlying buffers in the given message. */
+  public void release(T message) {
+    final InputStream stream = unclosedStreams.remove(message);
+    if (stream == null) {
+      return;
+    }
+    try {
+      stream.close();
+    } catch (IOException e) {
+      LOG.error(name + ": Failed to close stream.", e);
+    }
+  }
+
+  private List<ByteString> getByteStrings(InputStream detached, int exactSize) 
throws IOException {
+    Preconditions.assertTrue(detached instanceof HasByteBuffer);
+
+    // This mark call is to keep buffer while traversing buffers using skip.
+    detached.mark(exactSize);
+    final List<ByteString> byteStrings = new LinkedList<>();
+    while (detached.available() != 0) {
+      final ByteBuffer buffer = ((HasByteBuffer)detached).getByteBuffer();
+      Objects.requireNonNull(buffer, "buffer == null");
+      byteStrings.add(UnsafeByteOperations.unsafeWrap(buffer));
+      final int remaining = buffer.remaining();
+      final long skipped = detached.skip(buffer.remaining());
+      Preconditions.assertSame(remaining, skipped, "skipped");
+    }
+    detached.reset();
+    return byteStrings;
+  }
+
+  /**
+   * Use a zero copy method to parse a message from the given stream.
+   *
+   * @return the parsed message if the given stream supports zero copy; 
otherwise, return null.
+   */
+  private T parseZeroCopy(InputStream stream) throws IOException {
+    if (!(stream instanceof KnownLength)) {
+      LOG.debug("stream is not KnownLength: {}", stream.getClass());
+      return null;
+    }
+    if (!(stream instanceof Detachable)) {
+      LOG.debug("stream is not Detachable: {}", stream.getClass());
+      return null;
+    }
+    if (!(stream instanceof HasByteBuffer)) {
+      LOG.debug("stream is not HasByteBuffer: {}", stream.getClass());
+      return null;
+    }
+    if (!((HasByteBuffer) stream).byteBufferSupported()) {
+      LOG.debug("stream is HasByteBuffer but not byteBufferSupported: {}", 
stream.getClass());
+      return null;
+    }
+
+    final int exactSize = stream.available();
+    InputStream detached = ((Detachable) stream).detach();
+    try {
+      final List<ByteString> byteStrings = getByteStrings(detached, exactSize);
+      final T message = parseFrom(byteStrings, exactSize);
+
+      final InputStream previous = unclosedStreams.put(message, detached);
+      Preconditions.assertNull(previous, "previous");
+
+      detached = null;
+      return message;
+    } finally {
+      if (detached != null) {
+        detached.close();
+      }
+    }
+  }
+
+  private T parseFrom(List<ByteString> byteStrings, int exactSize) {
+    final CodedInputStream codedInputStream = 
ByteString.copyFrom(byteStrings).newCodedInput();
+    codedInputStream.enableAliasing(true);
+    codedInputStream.setSizeLimit(exactSize);
+
+    try {
+      return parseFrom(codedInputStream);
+    } catch (InvalidProtocolBufferException e) {
+      throw Status.INTERNAL
+          .withDescription("Invalid protobuf byte sequence")
+          .withCause(e)
+          .asRuntimeException();
+    }
   }
 
   private T parseFrom(CodedInputStream stream) throws 
InvalidProtocolBufferException {
-    T message = parser.parseFrom(stream);
+    final T message = parser.parseFrom(stream);
     try {
       stream.checkLastTagWas(0);
       return message;
@@ -131,12 +213,4 @@ public class ZeroCopyMessageMarshaller<T extends 
MessageLite> implements Prototy
       throw e;
     }
   }
-
-  /**
-   * Application needs to call this function to get the stream for the message 
and
-   * call stream.close() function to return it to the pool.
-   */
-  public InputStream popStream(T message) {
-    return unclosedStreams.remove(message);
-  }
 }
\ No newline at end of file
diff --git 
a/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyReadinessChecker.java
 
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyReadinessChecker.java
deleted file mode 100644
index 5a20d83e9..000000000
--- 
a/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyReadinessChecker.java
+++ /dev/null
@@ -1,74 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.ratis.grpc.util;
-
-import org.apache.ratis.thirdparty.com.google.protobuf.MessageLite;
-import org.apache.ratis.thirdparty.io.grpc.KnownLength;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-/**
- * Checker to test whether a zero-copy masharller is available from the 
versions of gRPC and
- * Protobuf.
- */
-public final class ZeroCopyReadinessChecker {
-  static final Logger LOG = 
LoggerFactory.getLogger(ZeroCopyReadinessChecker.class);
-  private static final boolean IS_ZERO_COPY_READY;
-
-  private ZeroCopyReadinessChecker() {
-  }
-
-  static {
-    // Check whether io.grpc.Detachable exists?
-    boolean detachableClassExists = false;
-    try {
-      // Try to load Detachable interface in the package where KnownLength is 
in.
-      // This can be done directly by looking up io.grpc.Detachable but rather
-      // done indirectly to handle the case where gRPC is being shaded in a
-      // different package.
-      String knownLengthClassName = KnownLength.class.getName();
-      String detachableClassName =
-          knownLengthClassName.substring(0, 
knownLengthClassName.lastIndexOf('.') + 1)
-              + "Detachable";
-      // check if class exists.
-      Class.forName(detachableClassName);
-      detachableClassExists = true;
-    } catch (ClassNotFoundException ex) {
-      LOG.debug("io.grpc.Detachable not found", ex);
-    }
-    // Check whether com.google.protobuf.UnsafeByteOperations exists?
-    boolean unsafeByteOperationsClassExists = false;
-    try {
-      // Same above
-      String messageLiteClassName = MessageLite.class.getName();
-      String unsafeByteOperationsClassName =
-          messageLiteClassName.substring(0, 
messageLiteClassName.lastIndexOf('.') + 1)
-              + "UnsafeByteOperations";
-      // check if class exists.
-      Class.forName(unsafeByteOperationsClassName);
-      unsafeByteOperationsClassExists = true;
-    } catch (ClassNotFoundException ex) {
-      LOG.debug("com.google.protobuf.UnsafeByteOperations not found", ex);
-    }
-    IS_ZERO_COPY_READY = detachableClassExists && 
unsafeByteOperationsClassExists;
-  }
-
-  public static boolean isReady() {
-    return IS_ZERO_COPY_READY;
-  }
-}
diff --git a/ratis-proto/src/main/proto/Test.proto 
b/ratis-proto/src/main/proto/Test.proto
index 8d5769ff3..060f03daa 100644
--- a/ratis-proto/src/main/proto/Test.proto
+++ b/ratis-proto/src/main/proto/Test.proto
@@ -26,6 +26,9 @@ package org.apache.ratis.test;
 service Greeter {
     rpc Hello (stream HelloRequest)
         returns (stream HelloReply) {}
+
+    rpc Binary (stream BinaryRequest)
+        returns (stream BinaryReply) {}
 }
 
 message HelloRequest {
@@ -35,3 +38,12 @@ message HelloRequest {
 message HelloReply {
     string message = 1;
 }
+
+message BinaryRequest {
+    bytes data = 1;
+}
+
+message BinaryReply {
+    bytes data = 1;
+}
+
diff --git 
a/ratis-test/src/test/java/org/apache/ratis/grpc/util/GrpcZeroCopyTestClient.java
 
b/ratis-test/src/test/java/org/apache/ratis/grpc/util/GrpcZeroCopyTestClient.java
new file mode 100644
index 000000000..791d5a6d2
--- /dev/null
+++ 
b/ratis-test/src/test/java/org/apache/ratis/grpc/util/GrpcZeroCopyTestClient.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.ratis.grpc.util;
+
+import org.apache.ratis.test.proto.BinaryReply;
+import org.apache.ratis.test.proto.BinaryRequest;
+import org.apache.ratis.test.proto.GreeterGrpc;
+import org.apache.ratis.test.proto.GreeterGrpc.GreeterStub;
+import org.apache.ratis.test.proto.HelloReply;
+import org.apache.ratis.test.proto.HelloRequest;
+import org.apache.ratis.thirdparty.com.google.protobuf.ByteString;
+import org.apache.ratis.thirdparty.com.google.protobuf.UnsafeByteOperations;
+import org.apache.ratis.thirdparty.io.grpc.ManagedChannel;
+import org.apache.ratis.thirdparty.io.grpc.ManagedChannelBuilder;
+import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver;
+import org.apache.ratis.util.IOUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Queue;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.TimeUnit;
+import java.util.function.BiFunction;
+
+/** gRPC client for testing */
+class GrpcZeroCopyTestClient implements Closeable {
+  private static final Logger LOG = 
LoggerFactory.getLogger(GrpcZeroCopyTestClient.class);
+
+  @FunctionalInterface
+  interface StreamObserverFactory
+      extends BiFunction<GreeterStub, StreamObserver<HelloReply>, 
StreamObserver<HelloRequest>> {
+  }
+
+  private final ManagedChannel channel;
+
+  private final StreamObserver<HelloRequest> helloRequestHandler;
+  private final Queue<CompletableFuture<String>> helloReplies = new 
ConcurrentLinkedQueue<>();
+
+  private final StreamObserver<BinaryRequest> binaryRequestHandler;
+  private final Queue<CompletableFuture<ByteString>> binaryReplies = new 
ConcurrentLinkedQueue<>();
+
+  GrpcZeroCopyTestClient(String host, int port) {
+    this.channel = ManagedChannelBuilder.forAddress(host, port)
+        .usePlaintext()
+        .build();
+    final GreeterStub asyncStub = GreeterGrpc.newStub(channel);
+
+    final StreamObserver<HelloReply> helloResponseHandler = new 
StreamObserver<HelloReply>() {
+      @Override
+      public void onNext(HelloReply helloReply) {
+        helloReplies.poll().complete(helloReply.getMessage());
+      }
+
+      @Override
+      public void onError(Throwable throwable) {
+        LOG.info("onError", throwable);
+        completeExceptionally(throwable);
+      }
+
+      @Override
+      public void onCompleted() {
+        LOG.info("onCompleted");
+        completeExceptionally(new IllegalStateException("onCompleted"));
+      }
+
+      void completeExceptionally(Throwable throwable) {
+        helloReplies.forEach(f -> f.completeExceptionally(throwable));
+        helloReplies.clear();
+      }
+    };
+
+    this.helloRequestHandler = asyncStub.hello(helloResponseHandler);
+
+    final StreamObserver<BinaryReply> binaryResponseHandler = new 
StreamObserver<BinaryReply>() {
+      @Override
+      public void onNext(BinaryReply binaryReply) {
+        binaryReplies.poll().complete(binaryReply.getData());
+      }
+
+      @Override
+      public void onError(Throwable throwable) {
+        LOG.info("onError", throwable);
+        completeExceptionally(throwable);
+      }
+
+      @Override
+      public void onCompleted() {
+        LOG.info("onCompleted");
+        completeExceptionally(new IllegalStateException("onCompleted"));
+      }
+
+      void completeExceptionally(Throwable throwable) {
+        binaryReplies.forEach(f -> f.completeExceptionally(throwable));
+        binaryReplies.clear();
+      }
+    };
+    this.binaryRequestHandler = asyncStub.binary(binaryResponseHandler);
+  }
+
+  @Override
+  public void close() throws IOException {
+    try {
+      /* After the request handler is cancelled, no more life-cycle hooks are 
allowed,
+       * see {@link 
org.apache.ratis.thirdparty.io.grpc.ClientCall.Listener#cancel(String, 
Throwable)} */
+      // requestHandler.onCompleted();
+      channel.shutdown().awaitTermination(5, TimeUnit.SECONDS);
+    } catch (InterruptedException e) {
+      Thread.currentThread().interrupt();
+      throw IOUtils.toInterruptedIOException("Failed to close", e);
+    }
+  }
+
+  CompletableFuture<String> send(String name) {
+    LOG.info("send message {}", name);
+    final HelloRequest request = 
HelloRequest.newBuilder().setName(name).build();
+    final CompletableFuture<String> f = new CompletableFuture<>();
+    try {
+      helloRequestHandler.onNext(request);
+      helloReplies.offer(f);
+    } catch (IllegalStateException e) {
+      // already closed
+      f.completeExceptionally(e);
+    }
+    return f;
+  }
+
+  CompletableFuture<ByteString> send(ByteBuffer data) {
+    LOG.info("send data: size={}, direct? {}", data.remaining(), 
data.isDirect());
+    final BinaryRequest request = 
BinaryRequest.newBuilder().setData(UnsafeByteOperations.unsafeWrap(data)).build();
+    final CompletableFuture<ByteString> f = new CompletableFuture<>();
+    try {
+      binaryRequestHandler.onNext(request);
+      binaryReplies.offer(f);
+    } catch (IllegalStateException e) {
+      // already closed
+      f.completeExceptionally(e);
+    }
+    return f;
+  }
+
+}
\ No newline at end of file
diff --git 
a/ratis-test/src/test/java/org/apache/ratis/grpc/util/GrpcZeroCopyTestServer.java
 
b/ratis-test/src/test/java/org/apache/ratis/grpc/util/GrpcZeroCopyTestServer.java
new file mode 100644
index 000000000..e1bfe4e22
--- /dev/null
+++ 
b/ratis-test/src/test/java/org/apache/ratis/grpc/util/GrpcZeroCopyTestServer.java
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.ratis.grpc.util;
+
+import org.apache.ratis.test.proto.BinaryReply;
+import org.apache.ratis.test.proto.BinaryRequest;
+import org.apache.ratis.test.proto.GreeterGrpc;
+import org.apache.ratis.test.proto.HelloReply;
+import org.apache.ratis.test.proto.HelloRequest;
+import org.apache.ratis.thirdparty.com.google.protobuf.ByteString;
+import org.apache.ratis.thirdparty.com.google.protobuf.UnsafeByteOperations;
+import org.apache.ratis.thirdparty.io.grpc.MethodDescriptor;
+import org.apache.ratis.thirdparty.io.grpc.Server;
+import org.apache.ratis.thirdparty.io.grpc.ServerBuilder;
+import org.apache.ratis.thirdparty.io.grpc.ServerMethodDefinition;
+import org.apache.ratis.thirdparty.io.grpc.ServerServiceDefinition;
+import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver;
+import org.apache.ratis.util.IOUtils;
+import org.apache.ratis.util.TraditionalBinaryPrefix;
+import org.junit.Assert;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+
+/** gRPC server for testing */
+class GrpcZeroCopyTestServer implements Closeable {
+  private static final Logger LOG = 
LoggerFactory.getLogger(GrpcZeroCopyTestServer.class);
+
+  static class Count {
+    private int numElements;
+    private long numBytes;
+
+    synchronized int getNumElements() {
+      return numElements;
+    }
+
+    synchronized long getNumBytes() {
+      return numBytes;
+    }
+
+    synchronized void inc(ByteString data) {
+      numElements++;
+      numBytes += data.size();
+    }
+
+    void inc(BinaryRequest request) {
+      inc(request.getData());
+    }
+
+    @Override
+    public synchronized String toString() {
+      return numElements + ", " + 
TraditionalBinaryPrefix.long2String(numBytes) + "B";
+    }
+  }
+
+  private final Count zeroCopyCount = new Count();
+  private final Count nonZeroCopyCount = new Count();
+
+  private final Server server;
+  private final ZeroCopyMessageMarshaller<BinaryRequest> marshaller = new 
ZeroCopyMessageMarshaller<>(
+      BinaryRequest.getDefaultInstance(),
+      zeroCopyCount::inc,
+      nonZeroCopyCount::inc);
+
+  GrpcZeroCopyTestServer(int port) {
+    final GreeterImpl greeter = new GreeterImpl();
+    final MethodDescriptor<BinaryRequest, BinaryReply> binary = 
GreeterGrpc.getBinaryMethod();
+    final String binaryFullMethodName = binary.getFullMethodName();
+    final ServerServiceDefinition service = greeter.bindService();
+    @SuppressWarnings("unchecked")
+    final ServerMethodDefinition<BinaryRequest, BinaryReply> method
+        = (ServerMethodDefinition<BinaryRequest, BinaryReply>) 
service.getMethod(binaryFullMethodName);
+    final ServerServiceDefinition.Builder builder = 
ServerServiceDefinition.builder(
+        service.getServiceDescriptor().getName());
+    
builder.addMethod(binary.toBuilder().setRequestMarshaller(marshaller).build(), 
method.getServerCallHandler());
+
+    service.getMethods().stream()
+        .filter(m -> 
!m.getMethodDescriptor().getFullMethodName().equals(binaryFullMethodName))
+        .forEach(builder::addMethod);
+
+    this.server = ServerBuilder.forPort(port)
+        .maxInboundMessageSize(Integer.MAX_VALUE)
+        .addService(builder.build())
+        .build();
+  }
+
+  Count getZeroCopyCount() {
+    return zeroCopyCount;
+  }
+
+  Count getNonZeroCopyCount() {
+    return nonZeroCopyCount;
+  }
+
+  void assertCounts(int expectNumElements, long expectNumBytes) {
+    LOG.info("ZeroCopyCount    = {}", zeroCopyCount);
+    LOG.info("nonZeroCopyCount = {}", nonZeroCopyCount);
+    Assert.assertEquals("zeroCopyCount.getNumElements()", expectNumElements, 
zeroCopyCount.getNumElements());
+    Assert.assertEquals("zeroCopyCount.getNumBytes()", expectNumBytes, 
zeroCopyCount.getNumBytes());
+    Assert.assertEquals("nonZeroCopyCount.getNumElements()", 0, 
nonZeroCopyCount.getNumElements());
+    Assert.assertEquals("nonZeroCopyCount.getNumBytes()", 0, 
nonZeroCopyCount.getNumBytes());
+  }
+
+  int start() throws IOException {
+    server.start();
+    return server.getPort();
+  }
+
+  @Override
+  public void close() throws IOException {
+    try {
+      server.shutdown().awaitTermination(5, TimeUnit.SECONDS);
+    } catch (InterruptedException e) {
+      Thread.currentThread().interrupt();
+      throw IOUtils.toInterruptedIOException("Failed to close", e);
+    }
+  }
+
+  static String toReply(int i, String request) {
+    return i + ") hi " + request;
+  }
+
+  class GreeterImpl extends GreeterGrpc.GreeterImplBase {
+    @Override
+    public StreamObserver<HelloRequest> hello(StreamObserver<HelloReply> 
responseObserver) {
+      final AtomicInteger count = new AtomicInteger();
+      return new StreamObserver<HelloRequest>() {
+        @Override
+        public void onNext(HelloRequest request) {
+          final String reply = toReply(count.getAndIncrement(), 
request.getName());
+          LOG.info("reply {}", reply);
+          
responseObserver.onNext(HelloReply.newBuilder().setMessage(reply).build());
+        }
+
+        @Override
+        public void onError(Throwable throwable) {
+          LOG.error("onError", throwable);
+        }
+
+        @Override
+        public void onCompleted() {
+          responseObserver.onCompleted();
+        }
+      };
+    }
+
+    @Override
+    public StreamObserver<BinaryRequest> binary(StreamObserver<BinaryReply> 
responseObserver) {
+      final AtomicInteger count = new AtomicInteger();
+      return new StreamObserver<BinaryRequest>() {
+        @Override
+        public void onNext(BinaryRequest request) {
+          try {
+            final ByteString data = request.getData();
+            int i = count.getAndIncrement();
+            LOG.info("Received {}) data.size() = {}", i, data.size());
+            TestGrpcZeroCopy.RandomData.verify(i, data);
+            final byte[] bytes = new byte[4];
+            ByteBuffer.wrap(bytes).putInt(data.size());
+            
responseObserver.onNext(BinaryReply.newBuilder().setData(UnsafeByteOperations.unsafeWrap(bytes)).build());
+          } finally {
+            marshaller.release(request);
+          }
+        }
+
+        @Override
+        public void onError(Throwable throwable) {
+          LOG.error("onError", throwable);
+        }
+
+        @Override
+        public void onCompleted() {
+          responseObserver.onCompleted();
+        }
+      };
+    }
+  }
+}
diff --git 
a/ratis-test/src/test/java/org/apache/ratis/grpc/util/TestGrpcZeroCopy.java 
b/ratis-test/src/test/java/org/apache/ratis/grpc/util/TestGrpcZeroCopy.java
new file mode 100644
index 000000000..a4592a118
--- /dev/null
+++ b/ratis-test/src/test/java/org/apache/ratis/grpc/util/TestGrpcZeroCopy.java
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.ratis.grpc.util;
+
+import org.apache.ratis.BaseTest;
+import org.apache.ratis.thirdparty.com.google.protobuf.ByteString;
+import org.apache.ratis.thirdparty.com.google.protobuf.MessageLite;
+import org.apache.ratis.thirdparty.com.google.protobuf.UnsafeByteOperations;
+import org.apache.ratis.thirdparty.io.grpc.KnownLength;
+import org.apache.ratis.thirdparty.io.netty.buffer.ByteBuf;
+import org.apache.ratis.thirdparty.io.netty.buffer.PooledByteBufAllocator;
+import org.apache.ratis.util.NetUtils;
+import org.apache.ratis.util.TraditionalBinaryPrefix;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.net.InetSocketAddress;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.CompletableFuture;
+
+/**
+ * Test gRPC zero-copy feature.
+ */
+public final class TestGrpcZeroCopy extends BaseTest {
+  static class RandomData {
+    private static final Random random = new Random();
+    private static final byte[] array = new byte[4096];
+
+    static void fill(long seed, int size, ByteBuf buf) {
+      random.setSeed(seed);
+      for(int offset = 0; offset < size; ) {
+        final int remaining = Math.min(size - offset, array.length);
+        random.nextBytes(array);
+        buf.writeBytes(array, 0, remaining);
+        offset += remaining;
+      }
+    }
+
+    static void verify(long seed, ByteString b) {
+      random.setSeed(seed);
+      final int size = b.size();
+      for(int offset = 0; offset < size; ) {
+        final int remaining = Math.min(size - offset, array.length);
+        random.nextBytes(array);
+        final ByteString expected = UnsafeByteOperations.unsafeWrap(array, 0, 
remaining);
+        final ByteString computed = b.substring(offset, offset + remaining);
+        Assert.assertEquals(expected.size(), computed.size());
+        Assert.assertEquals(expected, computed);
+        offset += remaining;
+      }
+    }
+  }
+
+  private static final boolean IS_ZERO_COPY_READY;
+
+  static {
+    // Check whether the Detachable class exists.
+    boolean detachableClassExists = false;
+    final String detachableClassName = 
KnownLength.class.getPackage().getName() + ".Detachable";
+    try {
+      Class.forName(detachableClassName);
+      detachableClassExists = true;
+    } catch (ClassNotFoundException e) {
+      e.printStackTrace(System.out);
+    }
+
+    // Check whether the UnsafeByteOperations class exists.
+    boolean unsafeByteOperationsClassExists = false;
+    final String unsafeByteOperationsClassName = 
MessageLite.class.getPackage().getName() + ".UnsafeByteOperations";
+    try {
+      Class.forName(unsafeByteOperationsClassName);
+      unsafeByteOperationsClassExists = true;
+    } catch (ClassNotFoundException e) {
+      e.printStackTrace(System.out);
+    }
+    IS_ZERO_COPY_READY = detachableClassExists && 
unsafeByteOperationsClassExists;
+  }
+
+  public static boolean isReady() {
+    return IS_ZERO_COPY_READY;
+  }
+
+  /** Test that a zero-copy marshaller is available from the versions of gRPC and 
Protobuf. */
+  @Test
+  public void testReadiness() {
+    Assert.assertTrue(isReady());
+  }
+
+
+  @Test
+  public void testZeroCopy() throws Exception {
+    runTestZeroCopy();
+  }
+
+  void runTestZeroCopy() throws Exception {
+    final InetSocketAddress address = NetUtils.createLocalServerAddress();
+
+    try (GrpcZeroCopyTestServer server = new 
GrpcZeroCopyTestServer(address.getPort())) {
+      final int port = server.start();
+      try (GrpcZeroCopyTestClient client = new 
GrpcZeroCopyTestClient(address.getHostName(), port)) {
+        sendMessages(5, client, server);
+        sendBinaries(11, client, server);
+      }
+    }
+  }
+
+  void sendMessages(int n, GrpcZeroCopyTestClient client, 
GrpcZeroCopyTestServer server) throws Exception {
+    final List<String> messages = new ArrayList<>();
+    for (int i = 0; i < n; i++) {
+      messages.add("m" + i);
+    }
+
+    final List<CompletableFuture<String>> futures = new ArrayList<>();
+    for (String m : messages) {
+      futures.add(client.send(m));
+    }
+
+    final int numElements = server.getZeroCopyCount().getNumElements();
+    final long numBytes = server.getZeroCopyCount().getNumBytes();
+    for (int i = 0; i < futures.size(); i++) {
+      final String expected = GrpcZeroCopyTestServer.toReply(i, 
messages.get(i));
+      final String reply = futures.get(i).get();
+      Assert.assertEquals("expected = " + expected + " != reply = " + reply, 
expected, reply);
+      server.assertCounts(numElements, numBytes);
+    }
+  }
+
+  void sendBinaries(int n, GrpcZeroCopyTestClient client, 
GrpcZeroCopyTestServer server) throws Exception {
+    final PooledByteBufAllocator allocator = PooledByteBufAllocator.DEFAULT;
+
+    int numElements = server.getZeroCopyCount().getNumElements();
+    long numBytes = server.getZeroCopyCount().getNumBytes();
+
+    for (int i = 0; i < n; i++) {
+      final int size = 16 << (2 * i);
+      LOG.info("buf {}: {}B", i, TraditionalBinaryPrefix.long2String(size));
+
+      final CompletableFuture<ByteString> future;
+      final ByteBuf buf = allocator.directBuffer(size, size);
+      try {
+        RandomData.fill(i, size, buf);
+        future = client.send(buf.nioBuffer(0, buf.capacity()));
+      } finally {
+        buf.release();
+      }
+
+      final ByteString reply = future.get();
+      Assert.assertEquals(4, reply.size());
+      Assert.assertEquals(size, reply.asReadOnlyByteBuffer().getInt());
+
+      numElements++;
+      numBytes += size;
+      server.assertCounts(numElements, numBytes);
+    }
+  }
+}


Reply via email to