This is an automated email from the ASF dual-hosted git repository.
szetszwo pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ratis.git
The following commit(s) were added to refs/heads/master by this push:
new e96ed1a33 RATIS-2168. Support custom gRPC services. (#1169)
e96ed1a33 is described below
commit e96ed1a33840385446f4e647864a169467da5ab7
Author: Tsz-Wo Nicholas Sze <[email protected]>
AuthorDate: Wed Oct 16 18:06:59 2024 -0700
RATIS-2168. Support custom gRPC services. (#1169)
---
ratis-grpc/pom.xml | 4 +
.../java/org/apache/ratis/grpc/GrpcConfigKeys.java | 29 ++-
.../java/org/apache/ratis/grpc/GrpcFactory.java | 19 +-
.../apache/ratis/grpc/server/GrpcLogAppender.java | 6 +-
.../org/apache/ratis/grpc/server/GrpcServices.java | 56 ++++++
.../{GrpcService.java => GrpcServicesImpl.java} | 45 +++--
.../apache/ratis/grpc/MiniRaftClusterWithGrpc.java | 6 +-
.../apache/ratis/grpc/TestCustomGrpcServices.java | 205 +++++++++++++++++++++
.../grpc/{ => server}/TestGrpcMessageMetrics.java | 8 +-
9 files changed, 339 insertions(+), 39 deletions(-)
diff --git a/ratis-grpc/pom.xml b/ratis-grpc/pom.xml
index 2e542215c..fc6797831 100644
--- a/ratis-grpc/pom.xml
+++ b/ratis-grpc/pom.xml
@@ -53,6 +53,10 @@
<scope>test</scope>
<type>test-jar</type>
</dependency>
+ <dependency>
+ <groupId>org.apache.ratis</groupId>
+ <artifactId>ratis-server-api</artifactId>
+ </dependency>
<dependency>
<artifactId>ratis-server</artifactId>
<groupId>org.apache.ratis</groupId>
diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcConfigKeys.java
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcConfigKeys.java
index c14d844ee..2fcb9b6b0 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcConfigKeys.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcConfigKeys.java
@@ -19,6 +19,7 @@ package org.apache.ratis.grpc;
import org.apache.ratis.conf.Parameters;
import org.apache.ratis.conf.RaftProperties;
+import org.apache.ratis.grpc.server.GrpcServices;
import org.apache.ratis.server.RaftServerConfigKeys;
import org.apache.ratis.util.SizeInBytes;
import org.apache.ratis.util.TimeDuration;
@@ -230,15 +231,6 @@ public interface GrpcConfigKeys {
setInt(properties::setInt, ASYNC_REQUEST_THREAD_POOL_SIZE_KEY, port);
}
- String TLS_CONF_PARAMETER = PREFIX + ".tls.conf";
- Class<GrpcTlsConfig> TLS_CONF_CLASS = TLS.CONF_CLASS;
- static GrpcTlsConfig tlsConf(Parameters parameters) {
- return parameters != null ? parameters.get(TLS_CONF_PARAMETER,
TLS_CONF_CLASS): null;
- }
- static void setTlsConf(Parameters parameters, GrpcTlsConfig conf) {
- parameters.put(TLS_CONF_PARAMETER, conf, TLS_CONF_CLASS);
- }
-
String LEADER_OUTSTANDING_APPENDS_MAX_KEY = PREFIX +
".leader.outstanding.appends.max";
int LEADER_OUTSTANDING_APPENDS_MAX_DEFAULT = 8;
static int leaderOutstandingAppendsMax(RaftProperties properties) {
@@ -301,6 +293,25 @@ public interface GrpcConfigKeys {
static void setZeroCopyEnabled(RaftProperties properties, boolean enabled)
{
setBoolean(properties::setBoolean, ZERO_COPY_ENABLED_KEY, enabled);
}
+
+    /** The {@link Parameters} key of the {@link GrpcServices.Customizer}. */
+    String SERVICES_CUSTOMIZER_PARAMETER = PREFIX + ".services.customizer";
+    /** The value type stored under {@link #SERVICES_CUSTOMIZER_PARAMETER}. */
+    Class<GrpcServices.Customizer> SERVICES_CUSTOMIZER_CLASS = GrpcServices.Customizer.class;
+
+    /**
+     * Get the gRPC services customizer from the given parameters.
+     *
+     * @param parameters the parameters to read from; may be null.
+     * @return null if the parameters are null;
+     *         otherwise, the value looked up under {@link #SERVICES_CUSTOMIZER_PARAMETER}.
+     */
+    static GrpcServices.Customizer servicesCustomizer(Parameters parameters) {
+      if (parameters == null) {
+        return null;
+      }
+      return parameters.get(SERVICES_CUSTOMIZER_PARAMETER, SERVICES_CUSTOMIZER_CLASS);
+    }
+
+    /**
+     * Put the given gRPC services customizer to the given parameters.
+     *
+     * @param parameters the parameters to write to; must be non-null.
+     * @param customizer the customizer to set.
+     */
+    static void setServicesCustomizer(Parameters parameters, GrpcServices.Customizer customizer) {
+      parameters.put(SERVICES_CUSTOMIZER_PARAMETER, customizer, SERVICES_CUSTOMIZER_CLASS);
+    }
+
+ /** The {@link Parameters} key of a {@link GrpcTlsConfig}. */
+ String TLS_CONF_PARAMETER = PREFIX + ".tls.conf";
+ /** The value type stored under {@link #TLS_CONF_PARAMETER}. */
+ Class<GrpcTlsConfig> TLS_CONF_CLASS = TLS.CONF_CLASS;
+ /** @return the TLS config from the given parameters, or null if the parameters are null. */
+ static GrpcTlsConfig tlsConf(Parameters parameters) {
+ return parameters != null ? parameters.get(TLS_CONF_PARAMETER,
TLS_CONF_CLASS): null;
+ }
+ /** Put the given TLS config to the given parameters, which must be non-null. */
+ static void setTlsConf(Parameters parameters, GrpcTlsConfig conf) {
+ parameters.put(TLS_CONF_PARAMETER, conf, TLS_CONF_CLASS);
+ }
}
String MESSAGE_SIZE_MAX_KEY = PREFIX + ".message.size.max";
diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcFactory.java
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcFactory.java
index 75eb34a2d..331d1a858 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcFactory.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcFactory.java
@@ -22,7 +22,8 @@ import org.apache.ratis.conf.Parameters;
import org.apache.ratis.conf.RaftProperties;
import org.apache.ratis.grpc.client.GrpcClientRpc;
import org.apache.ratis.grpc.server.GrpcLogAppender;
-import org.apache.ratis.grpc.server.GrpcService;
+import org.apache.ratis.grpc.server.GrpcServices;
+import org.apache.ratis.grpc.server.GrpcServicesImpl;
import org.apache.ratis.protocol.ClientId;
import org.apache.ratis.rpc.SupportedRpcType;
import org.apache.ratis.server.RaftServer;
@@ -64,6 +65,8 @@ public class GrpcFactory implements ServerFactory,
ClientFactory {
return value;
}
+ private final GrpcServices.Customizer servicesCustomizer;
+
private final GrpcTlsConfig tlsConfig;
private final GrpcTlsConfig adminTlsConfig;
private final GrpcTlsConfig clientTlsConfig;
@@ -76,7 +79,7 @@ public class GrpcFactory implements ServerFactory,
ClientFactory {
}
public GrpcFactory(Parameters parameters) {
- this(
+ this(GrpcConfigKeys.Server.servicesCustomizer(parameters),
GrpcConfigKeys.TLS.conf(parameters),
GrpcConfigKeys.Admin.tlsConf(parameters),
GrpcConfigKeys.Client.tlsConf(parameters),
@@ -85,11 +88,14 @@ public class GrpcFactory implements ServerFactory,
ClientFactory {
}
public GrpcFactory(GrpcTlsConfig tlsConfig) {
- this(tlsConfig, null, null, null);
+ this(null, tlsConfig, null, null, null);
}
- private GrpcFactory(GrpcTlsConfig tlsConfig, GrpcTlsConfig adminTlsConfig,
+ private GrpcFactory(GrpcServices.Customizer servicesCustomizer,
+ GrpcTlsConfig tlsConfig, GrpcTlsConfig adminTlsConfig,
GrpcTlsConfig clientTlsConfig, GrpcTlsConfig serverTlsConfig) {
+ this.servicesCustomizer = servicesCustomizer;
+
this.tlsConfig = tlsConfig;
this.adminTlsConfig = adminTlsConfig;
this.clientTlsConfig = clientTlsConfig;
@@ -123,10 +129,11 @@ public class GrpcFactory implements ServerFactory,
ClientFactory {
}
@Override
- public GrpcService newRaftServerRpc(RaftServer server) {
+ public GrpcServices newRaftServerRpc(RaftServer server) {
checkPooledByteBufAllocatorUseCacheForAllThreads(LOG::info);
- return GrpcService.newBuilder()
+ return GrpcServicesImpl.newBuilder()
.setServer(server)
+ .setCustomizer(servicesCustomizer)
.setAdminTlsConfig(getAdminTlsConfig())
.setServerTlsConfig(getServerTlsConfig())
.setClientTlsConfig(getClientTlsConfig())
diff --git
a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcLogAppender.java
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcLogAppender.java
index 18d4c62c6..45bc4c888 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcLogAppender.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcLogAppender.java
@@ -192,8 +192,8 @@ public class GrpcLogAppender extends LogAppenderBase {
}
@Override
- public GrpcService getServerRpc() {
- return (GrpcService)super.getServerRpc();
+ public GrpcServicesImpl getServerRpc() {
+ return (GrpcServicesImpl)super.getServerRpc();
}
private GrpcServerProtocolClient getClient() throws IOException {
@@ -428,7 +428,7 @@ public class GrpcLogAppender extends LogAppenderBase {
private void sendRequest(AppendEntriesRequest request,
AppendEntriesRequestProto proto) throws InterruptedIOException {
- CodeInjectionForTesting.execute(GrpcService.GRPC_SEND_SERVER_REQUEST,
+ CodeInjectionForTesting.execute(GrpcServicesImpl.GRPC_SEND_SERVER_REQUEST,
getServer().getId(), null, proto);
resetHeartbeatTrigger();
diff --git
a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServices.java
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServices.java
new file mode 100644
index 000000000..663fd6d74
--- /dev/null
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServices.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.ratis.grpc.server;
+
+import org.apache.ratis.server.RaftServerRpc;
+import org.apache.ratis.thirdparty.io.grpc.netty.NettyServerBuilder;
+
+import java.util.EnumSet;
+
+/** The gRPC services extending {@link RaftServerRpc}. */
+public interface GrpcServices extends RaftServerRpc {
+ /** The type of the services. */
+ enum Type {ADMIN, CLIENT, SERVER}
+
+ /**
+ * To customize the services.
+ * For example, add a custom service.
+ */
+ interface Customizer {
+ /** The default NOOP {@link Customizer}. */
+ class Default implements Customizer {
+ private static final Default INSTANCE = new Default();
+
+ @Override
+ public NettyServerBuilder customize(NettyServerBuilder builder,
EnumSet<GrpcServices.Type> types) {
+ return builder;
+ }
+ }
+
+ static Customizer getDefaultInstance() {
+ return Default.INSTANCE;
+ }
+
+ /**
+ * Customize the given builder for the given types.
+ *
+ * @return the customized builder.
+ */
+ NettyServerBuilder customize(NettyServerBuilder builder,
EnumSet<GrpcServices.Type> types);
+ }
+}
diff --git
a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcService.java
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServicesImpl.java
similarity index 92%
rename from
ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcService.java
rename to
ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServicesImpl.java
index 510dfcaa2..d6f6a0c86 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcService.java
+++
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServicesImpl.java
@@ -21,6 +21,7 @@ import org.apache.ratis.conf.RaftProperties;
import org.apache.ratis.grpc.GrpcConfigKeys;
import org.apache.ratis.grpc.GrpcTlsConfig;
import org.apache.ratis.grpc.GrpcUtil;
+import org.apache.ratis.grpc.metrics.MessageMetrics;
import org.apache.ratis.grpc.metrics.ZeroCopyMetrics;
import org.apache.ratis.grpc.metrics.intercept.server.MetricServerInterceptor;
import org.apache.ratis.protocol.AdminAsynchronousProtocol;
@@ -51,6 +52,7 @@ import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.net.InetSocketAddress;
+import java.util.EnumSet;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
@@ -60,11 +62,12 @@ import java.util.function.Supplier;
import static
org.apache.ratis.thirdparty.io.netty.handler.ssl.SslProvider.OPENSSL;
/** A grpc implementation of {@link org.apache.ratis.server.RaftServerRpc}. */
-public final class GrpcService extends
RaftServerRpcWithProxy<GrpcServerProtocolClient,
- PeerProxyMap<GrpcServerProtocolClient>> {
- static final Logger LOG = LoggerFactory.getLogger(GrpcService.class);
+public final class GrpcServicesImpl
+ extends RaftServerRpcWithProxy<GrpcServerProtocolClient,
PeerProxyMap<GrpcServerProtocolClient>>
+ implements GrpcServices {
+ static final Logger LOG = LoggerFactory.getLogger(GrpcServicesImpl.class);
public static final String GRPC_SEND_SERVER_REQUEST =
- JavaUtils.getClassSimpleName(GrpcService.class) + ".sendRequest";
+ JavaUtils.getClassSimpleName(GrpcServicesImpl.class) + ".sendRequest";
class AsyncService implements RaftServerAsynchronousProtocol {
@@ -102,6 +105,7 @@ public final class GrpcService extends
RaftServerRpcWithProxy<GrpcServerProtocol
public static final class Builder {
private RaftServer server;
+ private Customizer customizer;
private String adminHost;
private int adminPort;
@@ -150,6 +154,11 @@ public final class GrpcService extends
RaftServerRpcWithProxy<GrpcServerProtocol
return this;
}
+    /**
+     * Set the services customizer.
+     *
+     * @param customizer the customizer; when null, the default NOOP customizer is used.
+     * @return this builder.
+     */
+    public Builder setCustomizer(Customizer customizer) {
+      if (customizer == null) {
+        this.customizer = Customizer.getDefaultInstance();
+      } else {
+        this.customizer = customizer;
+      }
+      return this;
+    }
+
private GrpcServerProtocolClient newGrpcServerProtocolClient(RaftPeer
target) {
return new GrpcServerProtocolClient(target,
flowControlWindow.getSizeInt(),
requestTimeoutDuration, serverTlsConfig, separateHeartbeatChannel);
@@ -177,6 +186,10 @@ public final class GrpcService extends
RaftServerRpcWithProxy<GrpcServerProtocol
JavaUtils.getClassSimpleName(getClass()) + "_" + serverPort);
}
+ Server buildServer(NettyServerBuilder builder, EnumSet<GrpcServices.Type>
types) {
+ return customizer.customize(builder, types).build();
+ }
+
private NettyServerBuilder newNettyServerBuilderForServer() {
return newNettyServerBuilder(serverHost, serverPort, serverTlsConfig);
}
@@ -223,21 +236,24 @@ public final class GrpcService extends
RaftServerRpcWithProxy<GrpcServerProtocol
}
Server newServer(GrpcClientProtocolService client, ZeroCopyMetrics
zeroCopyMetrics, ServerInterceptor interceptor) {
+ final EnumSet<GrpcServices.Type> types =
EnumSet.of(GrpcServices.Type.SERVER);
final NettyServerBuilder serverBuilder =
newNettyServerBuilderForServer();
final ServerServiceDefinition service =
newGrpcServerProtocolService(zeroCopyMetrics).bindServiceWithZeroCopy();
serverBuilder.addService(ServerInterceptors.intercept(service,
interceptor));
if (!separateAdminServer()) {
+ types.add(GrpcServices.Type.ADMIN);
addAdminService(serverBuilder, server, interceptor);
}
if (!separateClientServer()) {
+ types.add(GrpcServices.Type.CLIENT);
addClientService(serverBuilder, client, interceptor);
}
- return serverBuilder.build();
+ return buildServer(serverBuilder, types);
}
- public GrpcService build() {
- return new GrpcService(this);
+ public GrpcServicesImpl build() {
+ return new GrpcServicesImpl(this);
}
public Builder setAdminTlsConfig(GrpcTlsConfig config) {
@@ -273,11 +289,7 @@ public final class GrpcService extends
RaftServerRpcWithProxy<GrpcServerProtocol
private final MetricServerInterceptor serverInterceptor;
private final ZeroCopyMetrics zeroCopyMetrics = new ZeroCopyMetrics();
- public MetricServerInterceptor getServerInterceptor() {
- return serverInterceptor;
- }
-
- private GrpcService(Builder b) {
+ private GrpcServicesImpl(Builder b) {
super(b.server::getId, id -> new PeerProxyMap<>(id.toString(),
b::newGrpcServerProtocolClient));
this.executor = b.newExecutor();
@@ -291,7 +303,7 @@ public final class GrpcService extends
RaftServerRpcWithProxy<GrpcServerProtocol
if (b.separateAdminServer()) {
final NettyServerBuilder builder = b.newNettyServerBuilderForAdmin();
addAdminService(builder, b.server, serverInterceptor);
- final Server adminServer = builder.build();
+ final Server adminServer = b.buildServer(builder,
EnumSet.of(GrpcServices.Type.ADMIN));
servers.put(GrpcAdminProtocolService.class.getName(), adminServer);
adminServerAddressSupplier = newAddressSupplier(b.adminPort,
adminServer);
} else {
@@ -301,7 +313,7 @@ public final class GrpcService extends
RaftServerRpcWithProxy<GrpcServerProtocol
if (b.separateClientServer()) {
final NettyServerBuilder builder = b.newNettyServerBuilderForClient();
addClientService(builder, clientProtocolService, serverInterceptor);
- final Server clientServer = builder.build();
+ final Server clientServer = b.buildServer(builder,
EnumSet.of(GrpcServices.Type.CLIENT));
servers.put(GrpcClientProtocolService.class.getName(), clientServer);
clientServerAddressSupplier = newAddressSupplier(b.clientPort,
clientServer);
} else {
@@ -419,6 +431,11 @@ public final class GrpcService extends
RaftServerRpcWithProxy<GrpcServerProtocol
return getProxies().getProxy(target).startLeaderElection(request);
}
+ @VisibleForTesting
+ MessageMetrics getMessageMetrics() {
+ return serverInterceptor.getMetrics();
+ }
+
@VisibleForTesting
public ZeroCopyMetrics getZeroCopyMetrics() {
return zeroCopyMetrics;
diff --git
a/ratis-grpc/src/test/java/org/apache/ratis/grpc/MiniRaftClusterWithGrpc.java
b/ratis-grpc/src/test/java/org/apache/ratis/grpc/MiniRaftClusterWithGrpc.java
index fe12e29f1..1519298f3 100644
---
a/ratis-grpc/src/test/java/org/apache/ratis/grpc/MiniRaftClusterWithGrpc.java
+++
b/ratis-grpc/src/test/java/org/apache/ratis/grpc/MiniRaftClusterWithGrpc.java
@@ -22,7 +22,7 @@ import org.apache.ratis.RaftTestUtil;
import org.apache.ratis.conf.Parameters;
import org.apache.ratis.conf.RaftProperties;
import org.apache.ratis.grpc.metrics.ZeroCopyMetrics;
-import org.apache.ratis.grpc.server.GrpcService;
+import org.apache.ratis.grpc.server.GrpcServicesImpl;
import org.apache.ratis.protocol.RaftGroup;
import org.apache.ratis.protocol.RaftPeer;
import org.apache.ratis.protocol.RaftPeerId;
@@ -63,7 +63,7 @@ public class MiniRaftClusterWithGrpc extends
MiniRaftCluster.RpcBase {
}
public static final DelayLocalExecutionInjection
SEND_SERVER_REQUEST_INJECTION =
- new DelayLocalExecutionInjection(GrpcService.GRPC_SEND_SERVER_REQUEST);
+ new
DelayLocalExecutionInjection(GrpcServicesImpl.GRPC_SEND_SERVER_REQUEST);
public MiniRaftClusterWithGrpc(String[] ids, RaftProperties properties,
Parameters parameters) {
this(ids, new String[0], properties, parameters);
@@ -102,7 +102,7 @@ public class MiniRaftClusterWithGrpc extends
MiniRaftCluster.RpcBase {
getServers().forEach(server -> server.getGroupIds().forEach(id -> {
LOG.info("Checking {}-{}", server.getId(), id);
RaftServer.Division division = RaftServerTestUtil.getDivision(server,
id);
- GrpcService service = (GrpcService)
RaftServerTestUtil.getServerRpc(division);
+ final GrpcServicesImpl service = (GrpcServicesImpl)
RaftServerTestUtil.getServerRpc(division);
ZeroCopyMetrics zeroCopyMetrics = service.getZeroCopyMetrics();
Assert.assertEquals(0, zeroCopyMetrics.nonZeroCopyMessages());
Assert.assertEquals("Zero copy messages are not released, please check
logs to find leaks. ",
diff --git
a/ratis-test/src/test/java/org/apache/ratis/grpc/TestCustomGrpcServices.java
b/ratis-test/src/test/java/org/apache/ratis/grpc/TestCustomGrpcServices.java
new file mode 100644
index 000000000..13c4a59fb
--- /dev/null
+++ b/ratis-test/src/test/java/org/apache/ratis/grpc/TestCustomGrpcServices.java
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.ratis.grpc;
+
+import org.apache.ratis.BaseTest;
+import org.apache.ratis.RaftTestUtil;
+import org.apache.ratis.client.RaftClient;
+import org.apache.ratis.conf.Parameters;
+import org.apache.ratis.conf.RaftProperties;
+import org.apache.ratis.grpc.server.GrpcServices;
+import org.apache.ratis.protocol.RaftClientReply;
+import org.apache.ratis.server.RaftServerRpc;
+import org.apache.ratis.test.proto.GreeterGrpc;
+import org.apache.ratis.test.proto.HelloReply;
+import org.apache.ratis.test.proto.HelloRequest;
+import org.apache.ratis.thirdparty.io.grpc.ManagedChannel;
+import org.apache.ratis.thirdparty.io.grpc.ManagedChannelBuilder;
+import org.apache.ratis.thirdparty.io.grpc.netty.NettyServerBuilder;
+import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver;
+import org.apache.ratis.util.IOUtils;
+import org.apache.ratis.util.NetUtils;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.EnumSet;
+import java.util.Objects;
+import java.util.Queue;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.TimeUnit;
+
+import static org.apache.ratis.RaftTestUtil.waitForLeader;
+
+public class TestCustomGrpcServices extends BaseTest {
+  /**
+   * Add two different greeter services, one for the client server and one for the admin server.
+   * Both greeters implement the same gRPC service, so at most one of them can be served by a given server:
+   * when the client and the admin services share a server, only the client greeter is added.
+   */
+  class MyCustomizer implements GrpcServices.Customizer {
+    final GreeterImpl clientGreeter = new GreeterImpl("Hello");
+    final GreeterImpl adminGreeter = new GreeterImpl("Hi");
+
+    @Override
+    public NettyServerBuilder customize(NettyServerBuilder builder, EnumSet<GrpcServices.Type> types) {
+      if (types.contains(GrpcServices.Type.CLIENT)) {
+        return builder.addService(clientGreeter);
+      }
+      if (types.contains(GrpcServices.Type.ADMIN)) {
+        return builder.addService(adminGreeter);
+      }
+      return builder;
+    }
+  }
+
+  /** A bidirectional-streaming greeter replying "{prefix} {name}" to each request. */
+  class GreeterImpl extends GreeterGrpc.GreeterImplBase {
+    private final String prefix;
+
+    GreeterImpl(String prefix) {
+      this.prefix = prefix;
+    }
+
+    /** @return the reply message this greeter sends for the given request name. */
+    String toReply(String request) {
+      return prefix + " " + request;
+    }
+
+    @Override
+    public StreamObserver<HelloRequest> hello(StreamObserver<HelloReply> responseObserver) {
+      return new StreamObserver<HelloRequest>() {
+        @Override
+        public void onNext(HelloRequest helloRequest) {
+          final String reply = toReply(helloRequest.getName());
+          responseObserver.onNext(HelloReply.newBuilder().setMessage(reply).build());
+        }
+
+        @Override
+        public void onError(Throwable throwable) {
+          LOG.error("onError", throwable);
+        }
+
+        @Override
+        public void onCompleted() {
+          responseObserver.onCompleted();
+        }
+      };
+    }
+  }
+
+  /**
+   * A greeter client sending all requests over a single stream.
+   * Replies are matched to the outstanding requests in FIFO order.
+   */
+  class GreeterClient implements Closeable {
+    private final ManagedChannel channel;
+    private final StreamObserver<HelloRequest> requestHandler;
+    /** The futures of the outstanding requests, in the order the requests were sent. */
+    private final Queue<CompletableFuture<String>> replies = new ConcurrentLinkedQueue<>();
+
+    GreeterClient(int port) {
+      this.channel = ManagedChannelBuilder.forAddress(NetUtils.LOCALHOST, port)
+          .usePlaintext()
+          .build();
+
+      final StreamObserver<HelloReply> responseHandler = new StreamObserver<HelloReply>() {
+        @Override
+        public void onNext(HelloReply helloReply) {
+          Objects.requireNonNull(replies.poll(), "queue is empty")
+              .complete(helloReply.getMessage());
+        }
+
+        @Override
+        public void onError(Throwable throwable) {
+          LOG.info("onError", throwable);
+          completeExceptionally(throwable);
+        }
+
+        @Override
+        public void onCompleted() {
+          LOG.info("onCompleted");
+          completeExceptionally(new IllegalStateException("onCompleted"));
+        }
+
+        /** Drain the queue so that no concurrently-offered future is dropped without completion. */
+        void completeExceptionally(Throwable throwable) {
+          for (CompletableFuture<String> f; (f = replies.poll()) != null; ) {
+            f.completeExceptionally(throwable);
+          }
+        }
+      };
+      this.requestHandler = GreeterGrpc.newStub(channel).hello(responseHandler);
+    }
+
+    @Override
+    public void close() throws IOException {
+      // Do not call requestHandler.onCompleted() here: after the call is cancelled,
+      // no more life-cycle hooks are allowed, see
+      // {@link org.apache.ratis.thirdparty.io.grpc.ClientCall#cancel(String, Throwable)}.
+      try {
+        if (!channel.shutdown().awaitTermination(5, TimeUnit.SECONDS)) {
+          channel.shutdownNow();
+        }
+      } catch (InterruptedException e) {
+        channel.shutdownNow();
+        Thread.currentThread().interrupt();
+        throw IOUtils.toInterruptedIOException("Failed to close", e);
+      }
+    }
+
+    /** Send a request with the given name; the returned future completes with the reply message. */
+    CompletableFuture<String> send(String name) {
+      LOG.info("send: {}", name);
+      final HelloRequest request = HelloRequest.newBuilder().setName(name).build();
+      final CompletableFuture<String> f = new CompletableFuture<>();
+      // Enqueue BEFORE sending: the reply may arrive on a gRPC thread before onNext returns.
+      replies.offer(f);
+      try {
+        requestHandler.onNext(request);
+      } catch (IllegalStateException e) {
+        // already closed
+        replies.remove(f);
+        f.completeExceptionally(e);
+      }
+      return f.whenComplete((r, e) -> LOG.info("reply: {}", r));
+    }
+  }
+
+  @Test
+  public void testCustomServices() throws Exception {
+    final String[] ids = {"s0"};
+    final RaftProperties properties = new RaftProperties();
+
+    final Parameters parameters = new Parameters();
+    final MyCustomizer customizer = new MyCustomizer();
+    GrpcConfigKeys.Server.setServicesCustomizer(parameters, customizer);
+
+    try (MiniRaftClusterWithGrpc cluster = new MiniRaftClusterWithGrpc(ids, properties, parameters)) {
+      cluster.start();
+      final RaftServerRpc server = waitForLeader(cluster).getRaftServer().getServerRpc();
+
+      // test Raft service
+      try (RaftClient client = cluster.createClient()) {
+        final RaftClientReply reply = client.io().send(new RaftTestUtil.SimpleMessage("abc"));
+        Assertions.assertTrue(reply.isSuccess());
+      }
+
+      // test custom client service
+      final int clientPort = server.getClientServerAddress().getPort();
+      try (GreeterClient client = new GreeterClient(clientPort)) {
+        sendAndAssertReply("world", client, customizer.clientGreeter);
+      }
+
+      // test custom admin service
+      final int adminPort = server.getAdminServerAddress().getPort();
+      try (GreeterClient admin = new GreeterClient(adminPort)) {
+        sendAndAssertReply("admin", admin, customizer.adminGreeter);
+      }
+    }
+  }
+
+  /** Send the given name and assert that the reply equals what the given greeter would reply. */
+  static void sendAndAssertReply(String name, GreeterClient client, GreeterImpl greeter) throws Exception {
+    // Bounded wait: a lost reply should fail the test rather than hang it.
+    final String computed = client.send(name).get(10, TimeUnit.SECONDS);
+    final String expected = greeter.toReply(name);
+    Assertions.assertEquals(expected, computed);
+  }
+}
diff --git
a/ratis-test/src/test/java/org/apache/ratis/grpc/TestGrpcMessageMetrics.java
b/ratis-test/src/test/java/org/apache/ratis/grpc/server/TestGrpcMessageMetrics.java
similarity index 91%
rename from
ratis-test/src/test/java/org/apache/ratis/grpc/TestGrpcMessageMetrics.java
rename to
ratis-test/src/test/java/org/apache/ratis/grpc/server/TestGrpcMessageMetrics.java
index 812c691e2..8094069cf 100644
--- a/ratis-test/src/test/java/org/apache/ratis/grpc/TestGrpcMessageMetrics.java
+++
b/ratis-test/src/test/java/org/apache/ratis/grpc/server/TestGrpcMessageMetrics.java
@@ -15,13 +15,13 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.ratis.grpc;
+package org.apache.ratis.grpc.server;
import org.apache.ratis.BaseTest;
+import org.apache.ratis.grpc.MiniRaftClusterWithGrpc;
import org.apache.ratis.server.impl.MiniRaftCluster;
import org.apache.ratis.RaftTestUtil;
import org.apache.ratis.client.RaftClient;
-import org.apache.ratis.grpc.server.GrpcService;
import org.apache.ratis.metrics.impl.JvmMetrics;
import org.apache.ratis.metrics.RatisMetricRegistry;
import org.apache.ratis.protocol.RaftClientReply;
@@ -66,8 +66,8 @@ public class TestGrpcMessageMetrics extends BaseTest
static void assertMessageCount(RaftServer.Division server) {
String serverId = server.getId().toString();
- GrpcService service = (GrpcService)
RaftServerTestUtil.getServerRpc(server);
- RatisMetricRegistry registry =
service.getServerInterceptor().getMetrics().getRegistry();
+ final GrpcServicesImpl services = (GrpcServicesImpl)
RaftServerTestUtil.getServerRpc(server);
+ final RatisMetricRegistry registry =
services.getMessageMetrics().getRegistry();
String counter_prefix = serverId + "_" +
"ratis.grpc.RaftServerProtocolService";
Assertions.assertTrue(
registry.counter(counter_prefix + "_" + "requestVote" +
"_OK_completed_total").getCount() > 0);