This is an automated email from the ASF dual-hosted git repository.
szetszwo pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ratis.git
The following commit(s) were added to refs/heads/master by this push:
new 06927826b RATIS-1939. Add TestGrpcZeroCopy. (#971)
06927826b is described below
commit 06927826b15e1c91628b6897b8638ef5b96330e2
Author: Tsz-Wo Nicholas Sze <[email protected]>
AuthorDate: Tue Nov 28 17:49:16 2023 -0800
RATIS-1939. Add TestGrpcZeroCopy. (#971)
---
.../apache/ratis/util/TraditionalBinaryPrefix.java | 5 +
.../ratis/grpc/util/ZeroCopyMessageMarshaller.java | 182 +++++++++++++------
.../ratis/grpc/util/ZeroCopyReadinessChecker.java | 74 --------
ratis-proto/src/main/proto/Test.proto | 12 ++
.../ratis/grpc/util/GrpcZeroCopyTestClient.java | 160 +++++++++++++++++
.../ratis/grpc/util/GrpcZeroCopyTestServer.java | 198 +++++++++++++++++++++
.../apache/ratis/grpc/util/TestGrpcZeroCopy.java | 173 ++++++++++++++++++
7 files changed, 676 insertions(+), 128 deletions(-)
diff --git a/ratis-common/src/main/java/org/apache/ratis/util/TraditionalBinaryPrefix.java b/ratis-common/src/main/java/org/apache/ratis/util/TraditionalBinaryPrefix.java
index 32697685e..b1902cfad 100644
--- a/ratis-common/src/main/java/org/apache/ratis/util/TraditionalBinaryPrefix.java
+++ b/ratis-common/src/main/java/org/apache/ratis/util/TraditionalBinaryPrefix.java
@@ -161,4 +161,9 @@ public enum TraditionalBinaryPrefix {
return b.append(' ').append(prefix.symbol).append(unit).toString();
}
}
+
+ /** The same as long2String(n, "", 3). */
+ public static String long2String(long n) {
+ return long2String(n, "", 3);
+ }
}
diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyMessageMarshaller.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyMessageMarshaller.java
index a47920fc5..f415fb006 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyMessageMarshaller.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyMessageMarshaller.java
@@ -29,6 +29,10 @@ import org.apache.ratis.thirdparty.io.grpc.KnownLength;
import
org.apache.ratis.thirdparty.io.grpc.MethodDescriptor.PrototypeMarshaller;
import org.apache.ratis.thirdparty.io.grpc.Status;
import org.apache.ratis.thirdparty.io.grpc.protobuf.lite.ProtoLiteUtils;
+import org.apache.ratis.util.JavaUtils;
+import org.apache.ratis.util.Preconditions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.io.InputStream;
@@ -38,6 +42,8 @@ import java.util.IdentityHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
+import java.util.function.Consumer;
/**
* Custom gRPC marshaller to use zero memory copy feature of gRPC when
deserializing messages. This
@@ -47,14 +53,29 @@ import java.util.Map;
* close it when it's no longer needed. Otherwise, it'd cause memory leak.
*/
public class ZeroCopyMessageMarshaller<T extends MessageLite> implements
PrototypeMarshaller<T> {
- private Map<T, InputStream> unclosedStreams =
- Collections.synchronizedMap(new IdentityHashMap<>());
+ static final Logger LOG =
LoggerFactory.getLogger(ZeroCopyMessageMarshaller.class);
+
+ private final String name;
+ private final Map<T, InputStream> unclosedStreams =
Collections.synchronizedMap(new IdentityHashMap<>());
private final Parser<T> parser;
private final PrototypeMarshaller<T> marshaller;
+ private final Consumer<T> zeroCopyCount;
+ private final Consumer<T> nonZeroCopyCount;
+
public ZeroCopyMessageMarshaller(T defaultInstance) {
- parser = (Parser<T>) defaultInstance.getParserForType();
- marshaller = (PrototypeMarshaller<T>)
ProtoLiteUtils.marshaller(defaultInstance);
+ this(defaultInstance, m -> {}, m -> {});
+ }
+
+ public ZeroCopyMessageMarshaller(T defaultInstance, Consumer<T>
zeroCopyCount, Consumer<T> nonZeroCopyCount) {
+ this.name = JavaUtils.getClassSimpleName(defaultInstance.getClass()) +
"-Marshaller";
+ @SuppressWarnings("unchecked")
+ final Parser<T> p = (Parser<T>) defaultInstance.getParserForType();
+ this.parser = p;
+ this.marshaller = (PrototypeMarshaller<T>)
ProtoLiteUtils.marshaller(defaultInstance);
+
+ this.zeroCopyCount = zeroCopyCount;
+ this.nonZeroCopyCount = nonZeroCopyCount;
}
@Override
@@ -74,55 +95,116 @@ public class ZeroCopyMessageMarshaller<T extends MessageLite> implements Prototy
@Override
public T parse(InputStream stream) {
+ final T message;
try {
- if (stream instanceof KnownLength
- && stream instanceof Detachable
- && stream instanceof HasByteBuffer
- && ((HasByteBuffer) stream).byteBufferSupported()) {
- int size = stream.available();
- // Stream is now detached here and should be closed later.
- InputStream detachedStream = ((Detachable) stream).detach();
- try {
- // This mark call is to keep buffer while traversing buffers using
skip.
- detachedStream.mark(size);
- List<ByteString> byteStrings = new LinkedList<>();
- while (detachedStream.available() != 0) {
- ByteBuffer buffer = ((HasByteBuffer)
detachedStream).getByteBuffer();
- byteStrings.add(UnsafeByteOperations.unsafeWrap(buffer));
- detachedStream.skip(buffer.remaining());
- }
- detachedStream.reset();
- CodedInputStream codedInputStream =
ByteString.copyFrom(byteStrings).newCodedInput();
- codedInputStream.enableAliasing(true);
- codedInputStream.setSizeLimit(Integer.MAX_VALUE);
- // fast path (no memory copy)
- T message;
- try {
- message = parseFrom(codedInputStream);
- } catch (InvalidProtocolBufferException ipbe) {
- throw Status.INTERNAL
- .withDescription("Invalid protobuf byte sequence")
- .withCause(ipbe)
- .asRuntimeException();
- }
- unclosedStreams.put(message, detachedStream);
- detachedStream = null;
- return message;
- } finally {
- if (detachedStream != null) {
- detachedStream.close();
- }
- }
- }
+ // fast path (no memory copy)
+ message = parseZeroCopy(stream);
} catch (IOException e) {
- throw new RuntimeException(e);
+ throw Status.INTERNAL
+ .withDescription("Failed to parseZeroCopy")
+ .withCause(e)
+ .asRuntimeException();
}
+ if (message != null) {
+ zeroCopyCount.accept(message);
+ return message;
+ }
+
// slow path
- return marshaller.parse(stream);
+ final T copied = marshaller.parse(stream);
+ nonZeroCopyCount.accept(copied);
+ return copied;
+ }
+
+ /** Release the underlying buffers in the given message. */
+ public void release(T message) {
+ final InputStream stream = unclosedStreams.remove(message);
+ if (stream == null) {
+ return;
+ }
+ try {
+ stream.close();
+ } catch (IOException e) {
+ LOG.error(name + ": Failed to close stream.", e);
+ }
+ }
+
+ private List<ByteString> getByteStrings(InputStream detached, int exactSize)
throws IOException {
+ Preconditions.assertTrue(detached instanceof HasByteBuffer);
+
+ // This mark call is to keep buffer while traversing buffers using skip.
+ detached.mark(exactSize);
+ final List<ByteString> byteStrings = new LinkedList<>();
+ while (detached.available() != 0) {
+ final ByteBuffer buffer = ((HasByteBuffer)detached).getByteBuffer();
+ Objects.requireNonNull(buffer, "buffer == null");
+ byteStrings.add(UnsafeByteOperations.unsafeWrap(buffer));
+ final int remaining = buffer.remaining();
+ final long skipped = detached.skip(buffer.remaining());
+ Preconditions.assertSame(remaining, skipped, "skipped");
+ }
+ detached.reset();
+ return byteStrings;
+ }
+
+ /**
+ * Use a zero copy method to parse a message from the given stream.
+ *
+ * @return the parsed message if the given stream support zero copy;
otherwise, return null.
+ */
+ private T parseZeroCopy(InputStream stream) throws IOException {
+ if (!(stream instanceof KnownLength)) {
+ LOG.debug("stream is not KnownLength: {}", stream.getClass());
+ return null;
+ }
+ if (!(stream instanceof Detachable)) {
+ LOG.debug("stream is not Detachable: {}", stream.getClass());
+ return null;
+ }
+ if (!(stream instanceof HasByteBuffer)) {
+ LOG.debug("stream is not HasByteBuffer: {}", stream.getClass());
+ return null;
+ }
+ if (!((HasByteBuffer) stream).byteBufferSupported()) {
+ LOG.debug("stream is HasByteBuffer but not byteBufferSupported: {}",
stream.getClass());
+ return null;
+ }
+
+ final int exactSize = stream.available();
+ InputStream detached = ((Detachable) stream).detach();
+ try {
+ final List<ByteString> byteStrings = getByteStrings(detached, exactSize);
+ final T message = parseFrom(byteStrings, exactSize);
+
+ final InputStream previous = unclosedStreams.put(message, detached);
+ Preconditions.assertNull(previous, "previous");
+
+ detached = null;
+ return message;
+ } finally {
+ if (detached != null) {
+ detached.close();
+ }
+ }
+ }
+
+ private T parseFrom(List<ByteString> byteStrings, int exactSize) {
+ final CodedInputStream codedInputStream =
ByteString.copyFrom(byteStrings).newCodedInput();
+ codedInputStream.enableAliasing(true);
+ codedInputStream.setSizeLimit(exactSize);
+
+ try {
+ return parseFrom(codedInputStream);
+ } catch (InvalidProtocolBufferException e) {
+ throw Status.INTERNAL
+ .withDescription("Invalid protobuf byte sequence")
+ .withCause(e)
+ .asRuntimeException();
+ }
}
private T parseFrom(CodedInputStream stream) throws
InvalidProtocolBufferException {
- T message = parser.parseFrom(stream);
+ final T message = parser.parseFrom(stream);
try {
stream.checkLastTagWas(0);
return message;
@@ -131,12 +213,4 @@ public class ZeroCopyMessageMarshaller<T extends MessageLite> implements Prototy
throw e;
}
}
-
- /**
- * Application needs to call this function to get the stream for the message
and
- * call stream.close() function to return it to the pool.
- */
- public InputStream popStream(T message) {
- return unclosedStreams.remove(message);
- }
}
\ No newline at end of file
diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyReadinessChecker.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyReadinessChecker.java
deleted file mode 100644
index 5a20d83e9..000000000
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyReadinessChecker.java
+++ /dev/null
@@ -1,74 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.ratis.grpc.util;
-
-import org.apache.ratis.thirdparty.com.google.protobuf.MessageLite;
-import org.apache.ratis.thirdparty.io.grpc.KnownLength;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-/**
- * Checker to test whether a zero-copy masharller is available from the
versions of gRPC and
- * Protobuf.
- */
-public final class ZeroCopyReadinessChecker {
- static final Logger LOG =
LoggerFactory.getLogger(ZeroCopyReadinessChecker.class);
- private static final boolean IS_ZERO_COPY_READY;
-
- private ZeroCopyReadinessChecker() {
- }
-
- static {
- // Check whether io.grpc.Detachable exists?
- boolean detachableClassExists = false;
- try {
- // Try to load Detachable interface in the package where KnownLength is
in.
- // This can be done directly by looking up io.grpc.Detachable but rather
- // done indirectly to handle the case where gRPC is being shaded in a
- // different package.
- String knownLengthClassName = KnownLength.class.getName();
- String detachableClassName =
- knownLengthClassName.substring(0,
knownLengthClassName.lastIndexOf('.') + 1)
- + "Detachable";
- // check if class exists.
- Class.forName(detachableClassName);
- detachableClassExists = true;
- } catch (ClassNotFoundException ex) {
- LOG.debug("io.grpc.Detachable not found", ex);
- }
- // Check whether com.google.protobuf.UnsafeByteOperations exists?
- boolean unsafeByteOperationsClassExists = false;
- try {
- // Same above
- String messageLiteClassName = MessageLite.class.getName();
- String unsafeByteOperationsClassName =
- messageLiteClassName.substring(0,
messageLiteClassName.lastIndexOf('.') + 1)
- + "UnsafeByteOperations";
- // check if class exists.
- Class.forName(unsafeByteOperationsClassName);
- unsafeByteOperationsClassExists = true;
- } catch (ClassNotFoundException ex) {
- LOG.debug("com.google.protobuf.UnsafeByteOperations not found", ex);
- }
- IS_ZERO_COPY_READY = detachableClassExists &&
unsafeByteOperationsClassExists;
- }
-
- public static boolean isReady() {
- return IS_ZERO_COPY_READY;
- }
-}
diff --git a/ratis-proto/src/main/proto/Test.proto b/ratis-proto/src/main/proto/Test.proto
index 8d5769ff3..060f03daa 100644
--- a/ratis-proto/src/main/proto/Test.proto
+++ b/ratis-proto/src/main/proto/Test.proto
@@ -26,6 +26,9 @@ package org.apache.ratis.test;
service Greeter {
rpc Hello (stream HelloRequest)
returns (stream HelloReply) {}
+
+ rpc Binary (stream BinaryRequest)
+ returns (stream BinaryReply) {}
}
message HelloRequest {
@@ -35,3 +38,12 @@ message HelloRequest {
message HelloReply {
string message = 1;
}
+
+message BinaryRequest {
+ bytes data = 1;
+}
+
+message BinaryReply {
+ bytes data = 1;
+}
+
diff --git a/ratis-test/src/test/java/org/apache/ratis/grpc/util/GrpcZeroCopyTestClient.java b/ratis-test/src/test/java/org/apache/ratis/grpc/util/GrpcZeroCopyTestClient.java
new file mode 100644
index 000000000..791d5a6d2
--- /dev/null
+++ b/ratis-test/src/test/java/org/apache/ratis/grpc/util/GrpcZeroCopyTestClient.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.ratis.grpc.util;
+
+import org.apache.ratis.test.proto.BinaryReply;
+import org.apache.ratis.test.proto.BinaryRequest;
+import org.apache.ratis.test.proto.GreeterGrpc;
+import org.apache.ratis.test.proto.GreeterGrpc.GreeterStub;
+import org.apache.ratis.test.proto.HelloReply;
+import org.apache.ratis.test.proto.HelloRequest;
+import org.apache.ratis.thirdparty.com.google.protobuf.ByteString;
+import org.apache.ratis.thirdparty.com.google.protobuf.UnsafeByteOperations;
+import org.apache.ratis.thirdparty.io.grpc.ManagedChannel;
+import org.apache.ratis.thirdparty.io.grpc.ManagedChannelBuilder;
+import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver;
+import org.apache.ratis.util.IOUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Queue;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.TimeUnit;
+import java.util.function.BiFunction;
+
+/** gRPC client for testing */
+class GrpcZeroCopyTestClient implements Closeable {
+ private static final Logger LOG =
LoggerFactory.getLogger(GrpcZeroCopyTestClient.class);
+
+ @FunctionalInterface
+ interface StreamObserverFactory
+ extends BiFunction<GreeterStub, StreamObserver<HelloReply>,
StreamObserver<HelloRequest>> {
+ }
+
+ private final ManagedChannel channel;
+
+ private final StreamObserver<HelloRequest> helloRequestHandler;
+ private final Queue<CompletableFuture<String>> helloReplies = new
ConcurrentLinkedQueue<>();
+
+ private final StreamObserver<BinaryRequest> binaryRequestHandler;
+ private final Queue<CompletableFuture<ByteString>> binaryReplies = new
ConcurrentLinkedQueue<>();
+
+ GrpcZeroCopyTestClient(String host, int port) {
+ this.channel = ManagedChannelBuilder.forAddress(host, port)
+ .usePlaintext()
+ .build();
+ final GreeterStub asyncStub = GreeterGrpc.newStub(channel);
+
+ final StreamObserver<HelloReply> helloResponseHandler = new
StreamObserver<HelloReply>() {
+ @Override
+ public void onNext(HelloReply helloReply) {
+ helloReplies.poll().complete(helloReply.getMessage());
+ }
+
+ @Override
+ public void onError(Throwable throwable) {
+ LOG.info("onError", throwable);
+ completeExceptionally(throwable);
+ }
+
+ @Override
+ public void onCompleted() {
+ LOG.info("onCompleted");
+ completeExceptionally(new IllegalStateException("onCompleted"));
+ }
+
+ void completeExceptionally(Throwable throwable) {
+ helloReplies.forEach(f -> f.completeExceptionally(throwable));
+ helloReplies.clear();
+ }
+ };
+
+ this.helloRequestHandler = asyncStub.hello(helloResponseHandler);
+
+ final StreamObserver<BinaryReply> binaryResponseHandler = new
StreamObserver<BinaryReply>() {
+ @Override
+ public void onNext(BinaryReply binaryReply) {
+ binaryReplies.poll().complete(binaryReply.getData());
+ }
+
+ @Override
+ public void onError(Throwable throwable) {
+ LOG.info("onError", throwable);
+ completeExceptionally(throwable);
+ }
+
+ @Override
+ public void onCompleted() {
+ LOG.info("onCompleted");
+ completeExceptionally(new IllegalStateException("onCompleted"));
+ }
+
+ void completeExceptionally(Throwable throwable) {
+ binaryReplies.forEach(f -> f.completeExceptionally(throwable));
+ binaryReplies.clear();
+ }
+ };
+ this.binaryRequestHandler = asyncStub.binary(binaryResponseHandler);
+ }
+
+ @Override
+ public void close() throws IOException {
+ try {
+ /* After the request handler is cancelled, no more life-cycle hooks are
allowed,
+ * see {@link
org.apache.ratis.thirdparty.io.grpc.ClientCall.Listener#cancel(String,
Throwable)} */
+ // requestHandler.onCompleted();
+ channel.shutdown().awaitTermination(5, TimeUnit.SECONDS);
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ throw IOUtils.toInterruptedIOException("Failed to close", e);
+ }
+ }
+
+ CompletableFuture<String> send(String name) {
+ LOG.info("send message {}", name);
+ final HelloRequest request =
HelloRequest.newBuilder().setName(name).build();
+ final CompletableFuture<String> f = new CompletableFuture<>();
+ try {
+ helloRequestHandler.onNext(request);
+ helloReplies.offer(f);
+ } catch (IllegalStateException e) {
+ // already closed
+ f.completeExceptionally(e);
+ }
+ return f;
+ }
+
+ CompletableFuture<ByteString> send(ByteBuffer data) {
+ LOG.info("send data: size={}, direct? {}", data.remaining(),
data.isDirect());
+ final BinaryRequest request =
BinaryRequest.newBuilder().setData(UnsafeByteOperations.unsafeWrap(data)).build();
+ final CompletableFuture<ByteString> f = new CompletableFuture<>();
+ try {
+ binaryRequestHandler.onNext(request);
+ binaryReplies.offer(f);
+ } catch (IllegalStateException e) {
+ // already closed
+ f.completeExceptionally(e);
+ }
+ return f;
+ }
+
+}
\ No newline at end of file
diff --git a/ratis-test/src/test/java/org/apache/ratis/grpc/util/GrpcZeroCopyTestServer.java b/ratis-test/src/test/java/org/apache/ratis/grpc/util/GrpcZeroCopyTestServer.java
new file mode 100644
index 000000000..e1bfe4e22
--- /dev/null
+++ b/ratis-test/src/test/java/org/apache/ratis/grpc/util/GrpcZeroCopyTestServer.java
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.ratis.grpc.util;
+
+import org.apache.ratis.test.proto.BinaryReply;
+import org.apache.ratis.test.proto.BinaryRequest;
+import org.apache.ratis.test.proto.GreeterGrpc;
+import org.apache.ratis.test.proto.HelloReply;
+import org.apache.ratis.test.proto.HelloRequest;
+import org.apache.ratis.thirdparty.com.google.protobuf.ByteString;
+import org.apache.ratis.thirdparty.com.google.protobuf.UnsafeByteOperations;
+import org.apache.ratis.thirdparty.io.grpc.MethodDescriptor;
+import org.apache.ratis.thirdparty.io.grpc.Server;
+import org.apache.ratis.thirdparty.io.grpc.ServerBuilder;
+import org.apache.ratis.thirdparty.io.grpc.ServerMethodDefinition;
+import org.apache.ratis.thirdparty.io.grpc.ServerServiceDefinition;
+import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver;
+import org.apache.ratis.util.IOUtils;
+import org.apache.ratis.util.TraditionalBinaryPrefix;
+import org.junit.Assert;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+
+/** gRPC server for testing */
+class GrpcZeroCopyTestServer implements Closeable {
+ private static final Logger LOG =
LoggerFactory.getLogger(GrpcZeroCopyTestServer.class);
+
+ static class Count {
+ private int numElements;
+ private long numBytes;
+
+ synchronized int getNumElements() {
+ return numElements;
+ }
+
+ synchronized long getNumBytes() {
+ return numBytes;
+ }
+
+ synchronized void inc(ByteString data) {
+ numElements++;
+ numBytes += data.size();
+ }
+
+ void inc(BinaryRequest request) {
+ inc(request.getData());
+ }
+
+ @Override
+ public synchronized String toString() {
+ return numElements + ", " +
TraditionalBinaryPrefix.long2String(numBytes) + "B";
+ }
+ }
+
+ private final Count zeroCopyCount = new Count();
+ private final Count nonZeroCopyCount = new Count();
+
+ private final Server server;
+ private final ZeroCopyMessageMarshaller<BinaryRequest> marshaller = new
ZeroCopyMessageMarshaller<>(
+ BinaryRequest.getDefaultInstance(),
+ zeroCopyCount::inc,
+ nonZeroCopyCount::inc);
+
+ GrpcZeroCopyTestServer(int port) {
+ final GreeterImpl greeter = new GreeterImpl();
+ final MethodDescriptor<BinaryRequest, BinaryReply> binary =
GreeterGrpc.getBinaryMethod();
+ final String binaryFullMethodName = binary.getFullMethodName();
+ final ServerServiceDefinition service = greeter.bindService();
+ @SuppressWarnings("unchecked")
+ final ServerMethodDefinition<BinaryRequest, BinaryReply> method
+ = (ServerMethodDefinition<BinaryRequest, BinaryReply>)
service.getMethod(binaryFullMethodName);
+ final ServerServiceDefinition.Builder builder =
ServerServiceDefinition.builder(
+ service.getServiceDescriptor().getName());
+
builder.addMethod(binary.toBuilder().setRequestMarshaller(marshaller).build(),
method.getServerCallHandler());
+
+ service.getMethods().stream()
+ .filter(m ->
!m.getMethodDescriptor().getFullMethodName().equals(binaryFullMethodName))
+ .forEach(builder::addMethod);
+
+ this.server = ServerBuilder.forPort(port)
+ .maxInboundMessageSize(Integer.MAX_VALUE)
+ .addService(builder.build())
+ .build();
+ }
+
+ Count getZeroCopyCount() {
+ return zeroCopyCount;
+ }
+
+ Count getNonZeroCopyCount() {
+ return nonZeroCopyCount;
+ }
+
+ void assertCounts(int expectNumElements, long expectNumBytes) {
+ LOG.info("ZeroCopyCount = {}", zeroCopyCount);
+ LOG.info("nonZeroCopyCount = {}", nonZeroCopyCount);
+ Assert.assertEquals("zeroCopyCount.getNumElements()", expectNumElements,
zeroCopyCount.getNumElements());
+ Assert.assertEquals("zeroCopyCount.getNumBytes()", expectNumBytes,
zeroCopyCount.getNumBytes());
+ Assert.assertEquals("nonZeroCopyCount.getNumElements()", 0,
nonZeroCopyCount.getNumElements());
+ Assert.assertEquals("nonZeroCopyCount.getNumBytes()", 0,
nonZeroCopyCount.getNumBytes());
+ }
+
+ int start() throws IOException {
+ server.start();
+ return server.getPort();
+ }
+
+ @Override
+ public void close() throws IOException {
+ try {
+ server.shutdown().awaitTermination(5, TimeUnit.SECONDS);
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ throw IOUtils.toInterruptedIOException("Failed to close", e);
+ }
+ }
+
+ static String toReply(int i, String request) {
+ return i + ") hi " + request;
+ }
+
+ class GreeterImpl extends GreeterGrpc.GreeterImplBase {
+ @Override
+ public StreamObserver<HelloRequest> hello(StreamObserver<HelloReply>
responseObserver) {
+ final AtomicInteger count = new AtomicInteger();
+ return new StreamObserver<HelloRequest>() {
+ @Override
+ public void onNext(HelloRequest request) {
+ final String reply = toReply(count.getAndIncrement(),
request.getName());
+ LOG.info("reply {}", reply);
+
responseObserver.onNext(HelloReply.newBuilder().setMessage(reply).build());
+ }
+
+ @Override
+ public void onError(Throwable throwable) {
+ LOG.error("onError", throwable);
+ }
+
+ @Override
+ public void onCompleted() {
+ responseObserver.onCompleted();
+ }
+ };
+ }
+
+ @Override
+ public StreamObserver<BinaryRequest> binary(StreamObserver<BinaryReply>
responseObserver) {
+ final AtomicInteger count = new AtomicInteger();
+ return new StreamObserver<BinaryRequest>() {
+ @Override
+ public void onNext(BinaryRequest request) {
+ try {
+ final ByteString data = request.getData();
+ int i = count.getAndIncrement();
+ LOG.info("Received {}) data.size() = {}", i, data.size());
+ TestGrpcZeroCopy.RandomData.verify(i, data);
+ final byte[] bytes = new byte[4];
+ ByteBuffer.wrap(bytes).putInt(data.size());
+
responseObserver.onNext(BinaryReply.newBuilder().setData(UnsafeByteOperations.unsafeWrap(bytes)).build());
+ } finally {
+ marshaller.release(request);
+ }
+ }
+
+ @Override
+ public void onError(Throwable throwable) {
+ LOG.error("onError", throwable);
+ }
+
+ @Override
+ public void onCompleted() {
+ responseObserver.onCompleted();
+ }
+ };
+ }
+ }
+}
diff --git a/ratis-test/src/test/java/org/apache/ratis/grpc/util/TestGrpcZeroCopy.java b/ratis-test/src/test/java/org/apache/ratis/grpc/util/TestGrpcZeroCopy.java
new file mode 100644
index 000000000..a4592a118
--- /dev/null
+++ b/ratis-test/src/test/java/org/apache/ratis/grpc/util/TestGrpcZeroCopy.java
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.ratis.grpc.util;
+
+import org.apache.ratis.BaseTest;
+import org.apache.ratis.thirdparty.com.google.protobuf.ByteString;
+import org.apache.ratis.thirdparty.com.google.protobuf.MessageLite;
+import org.apache.ratis.thirdparty.com.google.protobuf.UnsafeByteOperations;
+import org.apache.ratis.thirdparty.io.grpc.KnownLength;
+import org.apache.ratis.thirdparty.io.netty.buffer.ByteBuf;
+import org.apache.ratis.thirdparty.io.netty.buffer.PooledByteBufAllocator;
+import org.apache.ratis.util.NetUtils;
+import org.apache.ratis.util.TraditionalBinaryPrefix;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.net.InetSocketAddress;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.CompletableFuture;
+
+/**
+ * Test gRPC zero-copy feature.
+ */
+public final class TestGrpcZeroCopy extends BaseTest {
+ static class RandomData {
+ private static final Random random = new Random();
+ private static final byte[] array = new byte[4096];
+
+ static void fill(long seed, int size, ByteBuf buf) {
+ random.setSeed(seed);
+ for(int offset = 0; offset < size; ) {
+ final int remaining = Math.min(size - offset, array.length);
+ random.nextBytes(array);
+ buf.writeBytes(array, 0, remaining);
+ offset += remaining;
+ }
+ }
+
+ static void verify(long seed, ByteString b) {
+ random.setSeed(seed);
+ final int size = b.size();
+ for(int offset = 0; offset < size; ) {
+ final int remaining = Math.min(size - offset, array.length);
+ random.nextBytes(array);
+ final ByteString expected = UnsafeByteOperations.unsafeWrap(array, 0,
remaining);
+ final ByteString computed = b.substring(offset, offset + remaining);
+ Assert.assertEquals(expected.size(), computed.size());
+ Assert.assertEquals(expected, computed);
+ offset += remaining;
+ }
+ }
+ }
+
+ private static final boolean IS_ZERO_COPY_READY;
+
+ static {
+ // Check whether the Detachable class exists.
+ boolean detachableClassExists = false;
+ final String detachableClassName =
KnownLength.class.getPackage().getName() + ".Detachable";
+ try {
+ Class.forName(detachableClassName);
+ detachableClassExists = true;
+ } catch (ClassNotFoundException e) {
+ e.printStackTrace(System.out);
+ }
+
+ // Check whether the UnsafeByteOperations exists.
+ boolean unsafeByteOperationsClassExists = false;
+ final String unsafeByteOperationsClassName =
MessageLite.class.getPackage().getName() + ".UnsafeByteOperations";
+ try {
+ Class.forName(unsafeByteOperationsClassName);
+ unsafeByteOperationsClassExists = true;
+ } catch (ClassNotFoundException e) {
+ e.printStackTrace(System.out);
+ }
+ IS_ZERO_COPY_READY = detachableClassExists &&
unsafeByteOperationsClassExists;
+ }
+
+ public static boolean isReady() {
+ return IS_ZERO_COPY_READY;
+ }
+
+ /** Test a zero-copy marshaller is available from the versions of gRPC and
Protobuf. */
+ @Test
+ public void testReadiness() {
+ Assert.assertTrue(isReady());
+ }
+
+
+ @Test
+ public void testZeroCopy() throws Exception {
+ runTestZeroCopy();
+ }
+
+ void runTestZeroCopy() throws Exception {
+ final InetSocketAddress address = NetUtils.createLocalServerAddress();
+
+ try (GrpcZeroCopyTestServer server = new
GrpcZeroCopyTestServer(address.getPort())) {
+ final int port = server.start();
+ try (GrpcZeroCopyTestClient client = new
GrpcZeroCopyTestClient(address.getHostName(), port)) {
+ sendMessages(5, client, server);
+ sendBinaries(11, client, server);
+ }
+ }
+ }
+
+ void sendMessages(int n, GrpcZeroCopyTestClient client,
GrpcZeroCopyTestServer server) throws Exception {
+ final List<String> messages = new ArrayList<>();
+ for (int i = 0; i < n; i++) {
+ messages.add("m" + i);
+ }
+
+ final List<CompletableFuture<String>> futures = new ArrayList<>();
+ for (String m : messages) {
+ futures.add(client.send(m));
+ }
+
+ final int numElements = server.getZeroCopyCount().getNumElements();
+ final long numBytes = server.getZeroCopyCount().getNumBytes();
+ for (int i = 0; i < futures.size(); i++) {
+ final String expected = GrpcZeroCopyTestServer.toReply(i,
messages.get(i));
+ final String reply = futures.get(i).get();
+ Assert.assertEquals("expected = " + expected + " != reply = " + reply,
expected, reply);
+ server.assertCounts(numElements, numBytes);
+ }
+ }
+
+ void sendBinaries(int n, GrpcZeroCopyTestClient client,
GrpcZeroCopyTestServer server) throws Exception {
+ final PooledByteBufAllocator allocator = PooledByteBufAllocator.DEFAULT;
+
+ int numElements = server.getZeroCopyCount().getNumElements();
+ long numBytes = server.getZeroCopyCount().getNumBytes();
+
+ for (int i = 0; i < n; i++) {
+ final int size = 16 << (2 * i);
+ LOG.info("buf {}: {}B", i, TraditionalBinaryPrefix.long2String(size));
+
+ final CompletableFuture<ByteString> future;
+ final ByteBuf buf = allocator.directBuffer(size, size);
+ try {
+ RandomData.fill(i, size, buf);
+ future = client.send(buf.nioBuffer(0, buf.capacity()));
+ } finally {
+ buf.release();
+ }
+
+ final ByteString reply = future.get();
+ Assert.assertEquals(4, reply.size());
+ Assert.assertEquals(size, reply.asReadOnlyByteBuffer().getInt());
+
+ numElements++;
+ numBytes += size;
+ server.assertCounts(numElements, numBytes);
+ }
+ }
+}