[ https://issues.apache.org/jira/browse/FLINK-7738?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16556257#comment-16556257 ]
ASF GitHub Bot commented on FLINK-7738: --------------------------------------- EronWright closed pull request #4767: [FLINK-7738] [flip-6] Create WebSocket handler (server, client) URL: https://github.com/apache/flink/pull/4767 This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/flink-runtime-web/src/test/java/org/apache/flink/runtime/webmonitor/RedirectHandlerTest.java b/flink-runtime-web/src/test/java/org/apache/flink/runtime/webmonitor/RedirectHandlerTest.java index 4808781c7b8..cf465294f5a 100644 --- a/flink-runtime-web/src/test/java/org/apache/flink/runtime/webmonitor/RedirectHandlerTest.java +++ b/flink-runtime-web/src/test/java/org/apache/flink/runtime/webmonitor/RedirectHandlerTest.java @@ -29,17 +29,24 @@ import org.apache.flink.util.TestLogger; import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandlerAdapter; +import org.apache.flink.shaded.netty4.io.netty.channel.embedded.EmbeddedChannel; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.DefaultFullHttpRequest; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpMethod; import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponse; import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponseStatus; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpVersion; import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.router.KeepAliveWrite; import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.router.Routed; import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.router.Router; +import org.apache.flink.shaded.netty4.io.netty.util.ReferenceCountUtil; import 
org.junit.Assert; import org.junit.Test; import javax.annotation.Nonnull; +import java.util.HashMap; import java.util.Optional; import java.util.concurrent.CompletableFuture; @@ -137,6 +144,35 @@ public void testRedirectHandler() throws Exception { } } + /** + * Tests the approach of using the redirect handler as a standalone handler. + */ + @Test + public void testUserEvent() { + final String correctAddress = "foobar:21345"; + final CompletableFuture<String> localAddressFuture = CompletableFuture.completedFuture(correctAddress); + final Time timeout = Time.seconds(10L); + + final RestfulGateway localGateway = mock(RestfulGateway.class); + when(localGateway.requestRestAddress(any(Time.class))).thenReturn(CompletableFuture.completedFuture(correctAddress)); + final GatewayRetriever<RestfulGateway> gatewayRetriever = mock(GatewayRetriever.class); + when(gatewayRetriever.getNow()).thenReturn(Optional.of(localGateway)); + + final RedirectHandler<RestfulGateway> redirectHandler = new RedirectHandler<>( + localAddressFuture, + gatewayRetriever, + timeout); + final UserEventHandler eventHandler = new UserEventHandler(); + EmbeddedChannel channel = new EmbeddedChannel(redirectHandler, eventHandler); + + // write a (routed) HTTP request, then validate that a user event was propagated + DefaultFullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); + Routed routed = new Routed(null, false, request, "/", new HashMap<>(), new HashMap<>()); + channel.writeInbound(routed); + Assert.assertNotNull(eventHandler.gateway); + Assert.assertNotNull(eventHandler.routed); + } + private static class TestingHandler extends RedirectHandler<RestfulGateway> { protected TestingHandler( @@ -154,4 +190,25 @@ protected void respondAsLeader(ChannelHandlerContext channelHandlerContext, Rout } } + private static class UserEventHandler<T extends RestfulGateway> extends ChannelInboundHandlerAdapter { + + public volatile T gateway; + + public volatile Routed 
routed; + + @Override + @SuppressWarnings("unchecked") + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof RedirectHandler.GatewayRetrieved) { + gateway = ((RedirectHandler.GatewayRetrieved<T>) evt).getGateway(); + } + super.userEventTriggered(ctx, evt); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + routed = (Routed) msg; + ReferenceCountUtil.release(msg); + } + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestClient.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestClient.java index 3fcb85c5096..9c1de986ba0 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestClient.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestClient.java @@ -20,6 +20,7 @@ import org.apache.flink.api.common.time.Time; import org.apache.flink.configuration.ConfigConstants; +import org.apache.flink.runtime.rest.handler.JsonWebSocketMessageCodec; import org.apache.flink.runtime.rest.handler.PipelineErrorHandler; import org.apache.flink.runtime.rest.messages.EmptyMessageParameters; import org.apache.flink.runtime.rest.messages.EmptyRequestBody; @@ -28,9 +29,12 @@ import org.apache.flink.runtime.rest.messages.MessageParameters; import org.apache.flink.runtime.rest.messages.RequestBody; import org.apache.flink.runtime.rest.messages.ResponseBody; +import org.apache.flink.runtime.rest.messages.WebSocketSpecification; import org.apache.flink.runtime.rest.util.RestClientException; import org.apache.flink.runtime.rest.util.RestConstants; import org.apache.flink.runtime.rest.util.RestMapperUtils; +import org.apache.flink.runtime.rest.websocket.WebSocket; +import org.apache.flink.runtime.rest.websocket.WebSocketListener; import org.apache.flink.util.FlinkRuntimeException; import org.apache.flink.util.Preconditions; @@ -38,6 +42,7 @@ import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; 
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBufInputStream; import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled; +import org.apache.flink.shaded.netty4.io.netty.channel.Channel; import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFuture; import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInitializer; @@ -46,6 +51,7 @@ import org.apache.flink.shaded.netty4.io.netty.channel.socket.SocketChannel; import org.apache.flink.shaded.netty4.io.netty.channel.socket.nio.NioSocketChannel; import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.DefaultFullHttpRequest; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.DefaultHttpHeaders; import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.FullHttpRequest; import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.FullHttpResponse; import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpClientCodec; @@ -54,7 +60,10 @@ import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponse; import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponseStatus; import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpVersion; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.websocketx.WebSocketVersion; import org.apache.flink.shaded.netty4.io.netty.handler.ssl.SslHandler; +import org.apache.flink.shaded.netty4.io.netty.util.ReferenceCountUtil; import org.apache.flink.shaded.netty4.io.netty.util.concurrent.DefaultThreadFactory; import com.fasterxml.jackson.core.JsonParseException; @@ -69,7 +78,11 @@ import java.io.IOException; import java.io.InputStream; import java.io.StringWriter; +import java.net.URI; +import java.util.Arrays; +import java.util.List; import 
java.util.concurrent.CompletableFuture; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; @@ -81,38 +94,23 @@ private static final ObjectMapper objectMapper = RestMapperUtils.getStrictObjectMapper(); + private final RestClientConfiguration configuration; + // used to open connections to a rest server endpoint private final Executor executor; private Bootstrap bootstrap; public RestClient(RestClientConfiguration configuration, Executor executor) { - Preconditions.checkNotNull(configuration); + this.configuration = Preconditions.checkNotNull(configuration); this.executor = Preconditions.checkNotNull(executor); - SSLEngine sslEngine = configuration.getSslEngine(); - ChannelInitializer<SocketChannel> initializer = new ChannelInitializer<SocketChannel>() { - @Override - protected void initChannel(SocketChannel socketChannel) throws Exception { - // SSL should be the first handler in the pipeline - if (sslEngine != null) { - socketChannel.pipeline().addLast("ssl", new SslHandler(sslEngine)); - } - - socketChannel.pipeline() - .addLast(new HttpClientCodec()) - .addLast(new HttpObjectAggregator(1024 * 1024)) - .addLast(new ClientHandler()) - .addLast(new PipelineErrorHandler(LOG)); - } - }; NioEventLoopGroup group = new NioEventLoopGroup(1, new DefaultThreadFactory("flink-rest-client-netty")); bootstrap = new Bootstrap(); bootstrap .group(group) - .channel(NioSocketChannel.class) - .handler(initializer); + .channel(NioSocketChannel.class); LOG.info("Rest client endpoint started."); } @@ -175,7 +173,17 @@ public void shutdown(Time timeout) { } private <P extends ResponseBody> CompletableFuture<P> submitRequest(String targetAddress, int targetPort, FullHttpRequest httpRequest, Class<P> responseClass) { - return CompletableFuture.supplyAsync(() -> bootstrap.connect(targetAddress, targetPort), executor) + Bootstrap bootstrap1 = bootstrap.clone().handler(new ClientBootstrap() { + @Override + protected 
void initChannel(SocketChannel channel) throws Exception { + super.initChannel(channel); + channel.pipeline() + .addLast(new ClientHandler()) + .addLast(new PipelineErrorHandler(LOG)); + } + }); + + return CompletableFuture.supplyAsync(() -> bootstrap1.connect(targetAddress, targetPort), executor) .thenApply((channel) -> { try { return channel.sync(); @@ -221,6 +229,21 @@ public void shutdown(Time timeout) { return responseFuture; } + private class ClientBootstrap extends ChannelInitializer<SocketChannel> { + @Override + protected void initChannel(SocketChannel channel) throws Exception { + SSLEngine sslEngine = configuration.getSslEngine(); + + // SSL should be the first handler in the pipeline + if (sslEngine != null) { + channel.pipeline().addLast("ssl", new SslHandler(sslEngine)); + } + channel.pipeline() + .addLast(new HttpClientCodec()) + .addLast(new HttpObjectAggregator(1024 * 1024)); + } + } + private static class ClientHandler extends SimpleChannelInboundHandler<Object> { private final CompletableFuture<JsonResponse> jsonFuture = new CompletableFuture<>(); @@ -300,4 +323,116 @@ public HttpResponseStatus getHttpResponseStatus() { return httpResponseStatus; } } + + //M messageHeaders + @SuppressWarnings("unchecked") + public <M extends WebSocketSpecification<U, O, I>, U extends MessageParameters, I extends ResponseBody, O extends RequestBody> CompletableFuture<WebSocket<I, O>> sendWebSocketRequest(String targetAddress, int targetPort, M spec, U messageParameters, WebSocketListener<I>... 
listeners) throws IOException { + Preconditions.checkNotNull(targetAddress); + Preconditions.checkArgument(0 <= targetPort && targetPort < 65536, "The target port " + targetPort + " is not in the range (0, 65536]."); + Preconditions.checkNotNull(spec); + Preconditions.checkNotNull(messageParameters); + Preconditions.checkState(messageParameters.isResolved(), "Message parameters were not resolved."); + + String targetUrl = MessageParameters.resolveUrl(spec.getTargetRestEndpointURL(), messageParameters); + URI webSocketURL = URI.create("ws://" + targetAddress + ":" + targetPort).resolve(targetUrl); + LOG.debug("Sending WebSocket request to {}", webSocketURL); + + final HttpHeaders headers = new DefaultHttpHeaders() + .add(HttpHeaders.Names.CONTENT_TYPE, RestConstants.REST_CONTENT_TYPE); + + Bootstrap bootstrap1 = bootstrap.clone().handler(new ClientBootstrap() { + @Override + protected void initChannel(SocketChannel channel) throws Exception { + super.initChannel(channel); + channel.pipeline() + .addLast(new WebSocketClientProtocolHandler(webSocketURL, WebSocketVersion.V13, spec.getSubprotocol(), false, headers, 65535)) + .addLast(new JsonWebSocketMessageCodec<>(spec.getServerClass(), spec.getClientClass())) + .addLast(new WsResponseHandler<>(channel, spec.getServerClass(), spec.getClientClass(), listeners)); + } + }); + + return CompletableFuture.supplyAsync(() -> bootstrap1.connect(targetAddress, targetPort), executor) + .thenApply((channel) -> { + try { + return channel.sync(); + } catch (InterruptedException e) { + throw new FlinkRuntimeException(e); + } + }) + .thenApply((ChannelFuture::channel)) + .thenCompose(channel -> { + WsResponseHandler<I, O> handler = channel.pipeline().get(WsResponseHandler.class); + return handler.getWebSocketFuture(); + }); + } + + private static class WsResponseHandler<I extends ResponseBody, O extends RequestBody> extends SimpleChannelInboundHandler<I> implements WebSocket<I, O> { + + private final Channel channel; + private final 
List<WebSocketListener<I>> listeners = new CopyOnWriteArrayList<>(); + + private final CompletableFuture<WebSocket<I, O>> webSocketFuture = new CompletableFuture<>(); + + CompletableFuture<WebSocket<I, O>> getWebSocketFuture() { + return webSocketFuture; + } + + public WsResponseHandler(Channel channel, Class<I> inboundClass, Class<O> outboundClass, WebSocketListener<I>[] listeners) { + super(inboundClass); + this.channel = channel; + this.listeners.addAll(Arrays.asList(listeners)); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + LOG.warn("WebSocket exception caught", cause); + webSocketFuture.completeExceptionally(cause); + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof WebSocketClientProtocolHandler.ClientHandshakeStateEvent) { + WebSocketClientProtocolHandler.ClientHandshakeStateEvent wsevt = (WebSocketClientProtocolHandler.ClientHandshakeStateEvent) evt; + switch (wsevt) { + case HANDSHAKE_ISSUED: + LOG.debug("WebSocket handshake initiated"); + break; + case HANDSHAKE_COMPLETE: + LOG.debug("WebSocket handshake completed"); + webSocketFuture.complete(this); + break; + } + } + else { + super.userEventTriggered(ctx, evt); + } + } + + @Override + protected void channelRead0(ChannelHandlerContext channelHandlerContext, I msg) throws Exception { + for (WebSocketListener<I> listener : listeners) { + listener.onEvent(msg); + } + } + + @Override + public void addListener(WebSocketListener<I> listener) { + listeners.add(listener); + } + + @Override + public ChannelFuture send(O message) { + try { + return channel.writeAndFlush(message); + } + finally { + ReferenceCountUtil.release(message); + } + } + + @Override + public ChannelFuture close() { + return channel.close(); + } + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/AbstractWebSocketHandler.java 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/AbstractWebSocketHandler.java new file mode 100644 index 00000000000..3a21d8a77d9 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/AbstractWebSocketHandler.java @@ -0,0 +1,304 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.rest.handler; + +import org.apache.flink.api.common.time.Time; +import org.apache.flink.runtime.concurrent.FutureUtils; +import org.apache.flink.runtime.rest.handler.util.HandlerUtils; +import org.apache.flink.runtime.rest.messages.ErrorResponseBody; +import org.apache.flink.runtime.rest.messages.MessageParameters; +import org.apache.flink.runtime.rest.messages.RequestBody; +import org.apache.flink.runtime.rest.messages.ResponseBody; +import org.apache.flink.runtime.rest.messages.WebSocketSpecification; +import org.apache.flink.runtime.webmonitor.RestfulGateway; +import org.apache.flink.runtime.webmonitor.retriever.GatewayRetriever; +import org.apache.flink.util.ExceptionUtils; +import org.apache.flink.util.Preconditions; + +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandler; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandlerAdapter; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpRequest; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponseStatus; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.router.Routed; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; +import org.apache.flink.shaded.netty4.io.netty.util.AttributeKey; +import org.apache.flink.shaded.netty4.io.netty.util.ReferenceCountUtil; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; + +import java.util.concurrent.CompletableFuture; + +/** + * A channel handler for WebSocket resources. + * + * <p>This handler handles handshaking and ongoing messaging with a WebSocket client, + * based on a {@link WebSocketSpecification} that describes the REST resource location, + * parameter type, and message inbound/outbound types. 
Messages are automatically + * encoded from (and decoded to) JSON text. + * + * <p>Subclasses should override the following methods to extend the respective phases. + * <ol> + * <li>{@code handshakeInitiated} - occurs upon receipt of a handshake request from an HTTP client. Useful for parameter validation.</li> + * <li>{@code handshakeCompleted} - occurs upon successful completion; WebSocket is ready for I/O.</li> + * <li>{@code messageReceived} - occurs when a WebSocket message is received on the channel.</li> + * </ol> + * + * <p>The handler supports gateway availability announcements. + * + * @param <T> The gateway type. + * @param <M> The REST parameter type. + * @param <I> The inbound message type. + * @param <O> The outbound message type. + */ +public abstract class AbstractWebSocketHandler<T extends RestfulGateway, M extends MessageParameters, I extends RequestBody, O extends ResponseBody> extends ChannelInboundHandlerAdapter { + + protected final Logger log = LoggerFactory.getLogger(getClass()); + + private final RedirectHandler redirectHandler; + + private final AttributeKey<T> gatewayAttr; + + private final WebSocketSpecification<M, I, O> specification; + + private final ChannelHandler messageCodec; + + private final AttributeKey<M> parametersAttr; + + /** + * Creates a new handler. + */ + public AbstractWebSocketHandler( + @Nonnull CompletableFuture<String> localAddressFuture, + @Nonnull GatewayRetriever<? extends T> leaderRetriever, + @Nonnull Time timeout, + @Nonnull WebSocketSpecification<M, I, O> specification) { + this.redirectHandler = new RedirectHandler<>(localAddressFuture, leaderRetriever, timeout); + this.gatewayAttr = AttributeKey.valueOf("gateway"); + this.specification = specification; + this.messageCodec = new JsonWebSocketMessageCodec<>(specification.getClientClass(), specification.getServerClass()); + this.parametersAttr = AttributeKey.valueOf("parameters"); + } + + /** + * Sets the gateway associated with the channel. 
+ */ + private void setGateway(ChannelHandlerContext ctx, T gateway) { + ctx.attr(gatewayAttr).set(gateway); + } + + /** + * Returns the gateway associated with the channel. + */ + public T getGateway(ChannelHandlerContext ctx) { + T t = ctx.attr(gatewayAttr).get(); + Preconditions.checkState(t != null, "Gateway is not available."); + return t; + } + + /** + * Sets the resource parameters associated with the channel. + * + * <p>The parameters are established by the WebSocket handshake request. + */ + private void setMessageParameters(ChannelHandlerContext ctx, M parameters) { + ctx.attr(parametersAttr).set(parameters); + } + + /** + * Returns the resource parameters associated with the channel. + */ + public M getMessageParameters(ChannelHandlerContext ctx) { + M o = ctx.attr(parametersAttr).get(); + Preconditions.checkState(o != null, "Message parameters are not available."); + return o; + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + if (ctx.pipeline().get(RedirectHandler.class.getName()) == null) { + ctx.pipeline().addBefore(ctx.name(), RedirectHandler.class.getName(), redirectHandler); + } + } + + @Override + @SuppressWarnings("unchecked") + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof RedirectHandler.GatewayRetrieved) { + T gateway = ((RedirectHandler.GatewayRetrieved<T>) evt).getGateway(); + setGateway(ctx, gateway); + log.debug("Gateway retrieved: {}", gateway); + } + else if (evt instanceof WebSocketServerProtocolHandler.ServerHandshakeStateEvent) { + WebSocketServerProtocolHandler.ServerHandshakeStateEvent handshakeEvent = (WebSocketServerProtocolHandler.ServerHandshakeStateEvent) evt; + if (handshakeEvent == WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE) { + log.debug("Handshake completed with client IP: {}", ctx.channel().remoteAddress()); + M parameters = getMessageParameters(ctx); + handshakeCompleted(ctx, 
parameters); + } + } + + super.userEventTriggered(ctx, evt); + } + + @Override + @SuppressWarnings("unchecked") + public void channelRead(ChannelHandlerContext ctx, Object o) throws Exception { + if (specification.getClientClass().isAssignableFrom(o.getClass())) { + // process an inbound message + M parameters = getMessageParameters(ctx); + try { + messageReceived(ctx, parameters, (I) o); + } + finally { + ReferenceCountUtil.release(o); + } + return; + } + + if (!(o instanceof Routed)) { + // a foreign message + ctx.fireChannelRead(o); + return; + } + + // process an inbound HTTP request + Routed request = (Routed) o; + + // parse the REST request parameters + M messageParameters = specification.getUnresolvedMessageParameters(); + try { + messageParameters.resolveParameters(request.pathParams(), request.queryParams()); + if (!messageParameters.isResolved()) { + throw new IllegalArgumentException("One or more mandatory parameters is missing."); + } + } + catch (IllegalArgumentException e) { + HandlerUtils.sendErrorResponse( + ctx, + request.request(), + new ErrorResponseBody(String.format("Bad request, could not parse parameters: %s", e.getMessage())), + HttpResponseStatus.BAD_REQUEST); + ReferenceCountUtil.release(request); + return; + } + setMessageParameters(ctx, messageParameters); + + // validate the inbound handshake request with the subclass + CompletableFuture<Void> handshakeReady; + try { + handshakeReady = handshakeInitiated(ctx, messageParameters); + } catch (Exception e) { + handshakeReady = FutureUtils.completedExceptionally(e); + } + handshakeReady.whenCompleteAsync((Void v, Throwable throwable) -> { + try { + if (throwable != null) { + Throwable error = ExceptionUtils.stripCompletionException(throwable); + if (error instanceof RestHandlerException) { + final RestHandlerException rhe = (RestHandlerException) error; + log.error("Exception occurred in REST handler.", error); + HandlerUtils.sendErrorResponse(ctx, request.request(), new 
ErrorResponseBody(rhe.getMessage()), rhe.getHttpResponseStatus()); + } else { + log.error("Implementation error: Unhandled exception.", error); + HandlerUtils.sendErrorResponse(ctx, request.request(), new ErrorResponseBody("Internal server error."), HttpResponseStatus.INTERNAL_SERVER_ERROR); + } + } else { + upgradeToWebSocket(ctx, request); + } + } + finally { + ReferenceCountUtil.release(request); + } + }, ctx.executor()); + } + + private void upgradeToWebSocket(final ChannelHandlerContext ctx, Routed msg) { + + // store the context of the handler that precedes the current handler, + // to use that context later to forward the HTTP request to the WebSocket protocol handler + String before = ctx.pipeline().names().get(ctx.pipeline().names().indexOf(ctx.name()) - 1); + ChannelHandlerContext beforeCtx = ctx.pipeline().context(before); + + // inject the websocket protocol handler into this channel, to be active + // until the channel is closed. note that the handshake may or may not complete synchronously. + ctx.pipeline().addBefore(ctx.name(), WebSocketServerProtocolHandler.class.getName(), + new WebSocketServerProtocolHandler(msg.path(), specification.getSubprotocol())); + + // inject the message codec + ctx.pipeline().addBefore(ctx.name(), messageCodec.getClass().getName(), messageCodec); + + log.debug("Upgraded channel with WS protocol handler and message codec."); + + // forward the message to the installed protocol handler to initiate handshaking + HttpRequest request = msg.request(); + ReferenceCountUtil.retain(request); + beforeCtx.fireChannelRead(request); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + log.error("WebSocket channel error; closing the channel.", cause); + ctx.close(); + } + + /** + * Handles a client handshake request to open a WebSocket resource. Returns a {@link CompletableFuture} to complete handshaking. 
+ * + * <p>Implementations may decide whether to throw {@link RestHandlerException}s or fail the returned + * {@link CompletableFuture} with a {@link RestHandlerException}. + * + * <p>Failing the future with another exception type or throwing unchecked exceptions is regarded as an + * implementation error as it does not allow us to provide a meaningful HTTP status code. In this case a + * {@link HttpResponseStatus#INTERNAL_SERVER_ERROR} will be returned. + * + * @param parameters the REST parameters + + * @return future indicating completion of handshake pre-processing. + * @throws RestHandlerException to produce a pre-formatted HTTP error response. + */ + protected CompletableFuture<Void> handshakeInitiated(ChannelHandlerContext ctx, M parameters) throws Exception { + return CompletableFuture.completedFuture(null); + } + + /** + * Invoked when the current channel has completed the handshaking to establish a WebSocket connection. + * + * @param ctx the channel handler context + * @param parameters the REST parameters + * @throws Exception if processing failed. + */ + protected void handshakeCompleted(ChannelHandlerContext ctx, M parameters) throws Exception { + } + + /** + * Invoked when the current channel has received a WebSocket message. + * + * <p>The message object is automatically released after this method is called. + * + * @param ctx the channel handler context + * @param parameters the REST parameters + * @param msg the message received + * @throws Exception if the message could not be processed. 
+ */ + protected abstract void messageReceived(ChannelHandlerContext ctx, M parameters, I msg) throws Exception; +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/HandlerRequest.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/HandlerRequest.java index aacf0a22a16..f92282956b2 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/HandlerRequest.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/HandlerRequest.java @@ -25,10 +25,10 @@ import org.apache.flink.util.Preconditions; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.StringJoiner; +import java.util.function.Function; +import java.util.stream.Collectors; /** * Simple container for the request to a handler, that contains the {@link RequestBody} and path/query parameters. @@ -39,52 +39,29 @@ public class HandlerRequest<R extends RequestBody, M extends MessageParameters> { private final R requestBody; - private final Map<Class<? extends MessagePathParameter<?>>, MessagePathParameter<?>> pathParameters = new HashMap<>(2); - private final Map<Class<? extends MessageQueryParameter<?>>, MessageQueryParameter<?>> queryParameters = new HashMap<>(2); + private final Map<Class<? extends MessagePathParameter<?>>, MessagePathParameter<?>> pathParameters; + private final Map<Class<? 
extends MessageQueryParameter<?>>, MessageQueryParameter<?>> queryParameters; public HandlerRequest(R requestBody, M messageParameters) throws HandlerRequestException { this(requestBody, messageParameters, Collections.emptyMap(), Collections.emptyMap()); } + @SuppressWarnings("unchecked") public HandlerRequest(R requestBody, M messageParameters, Map<String, String> receivedPathParameters, Map<String, List<String>> receivedQueryParameters) throws HandlerRequestException { this.requestBody = Preconditions.checkNotNull(requestBody); Preconditions.checkNotNull(messageParameters); Preconditions.checkNotNull(receivedQueryParameters); Preconditions.checkNotNull(receivedPathParameters); - for (MessagePathParameter<?> pathParameter : messageParameters.getPathParameters()) { - String value = receivedPathParameters.get(pathParameter.getKey()); - if (value != null) { - try { - pathParameter.resolveFromString(value); - } catch (Exception e) { - throw new HandlerRequestException("Cannot resolve path parameter (" + pathParameter.getKey() + ") from value \"" + value + "\"."); - } - - @SuppressWarnings("unchecked") - Class<? extends MessagePathParameter<?>> clazz = (Class<? extends MessagePathParameter<?>>) pathParameter.getClass(); - pathParameters.put(clazz, pathParameter); - } + try { + messageParameters.resolveParameters(receivedPathParameters, receivedQueryParameters); } - - for (MessageQueryParameter<?> queryParameter : messageParameters.getQueryParameters()) { - List<String> values = receivedQueryParameters.get(queryParameter.getKey()); - if (values != null && !values.isEmpty()) { - StringJoiner joiner = new StringJoiner(","); - values.forEach(joiner::add); - - try { - queryParameter.resolveFromString(joiner.toString()); - } catch (Exception e) { - throw new HandlerRequestException("Cannot resolve query parameter (" + queryParameter.getKey() + ") from value \"" + joiner + "\"."); - } - - @SuppressWarnings("unchecked") - Class<? 
extends MessageQueryParameter<?>> clazz = (Class<? extends MessageQueryParameter<?>>) queryParameter.getClass(); - queryParameters.put(clazz, queryParameter); - } - + catch (IllegalArgumentException e) { + throw new HandlerRequestException("Unable to resolve the request parameters: " + e.getMessage()); } + + pathParameters = messageParameters.getPathParameters().stream().collect(Collectors.toMap(p -> (Class<? extends MessagePathParameter<?>>) p.getClass(), Function.identity())); + queryParameters = messageParameters.getQueryParameters().stream().collect(Collectors.toMap(p -> (Class<? extends MessageQueryParameter<?>>) p.getClass(), Function.identity())); } /** diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/JsonWebSocketMessageCodec.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/JsonWebSocketMessageCodec.java new file mode 100644 index 00000000000..48127bd88fe --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/JsonWebSocketMessageCodec.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.rest.handler; + +import org.apache.flink.runtime.rest.util.RestMapperUtils; + +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandler; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; +import org.apache.flink.shaded.netty4.io.netty.channel.CombinedChannelDuplexHandler; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.DecoderException; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.EncoderException; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.MessageToMessageDecoder; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.MessageToMessageEncoder; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.websocketx.TextWebSocketFrame; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.ObjectReader; +import com.fasterxml.jackson.databind.ObjectWriter; + +import java.io.StringReader; +import java.io.StringWriter; +import java.util.List; + +/** + * A codec for JSON-encoded WebSocket messages. + * + * @param <I> the inbound message type (converted from JSON). + * @param <O> the outbound message type (converted to JSON). 
+ */ +@ChannelHandler.Sharable +public class JsonWebSocketMessageCodec<I, O> extends CombinedChannelDuplexHandler<JsonWebSocketMessageCodec.WebSocketMessageDecoder<I>, JsonWebSocketMessageCodec.WebSocketMessageEncoder<O>> { + + private static final ObjectMapper mapper = RestMapperUtils.getStrictObjectMapper(); + + public JsonWebSocketMessageCodec(Class<I> inboundMessageClass, Class<O> outboundMessageClass) { + super(new WebSocketMessageDecoder<>(inboundMessageClass), new WebSocketMessageEncoder<>(outboundMessageClass)); + } + + @ChannelHandler.Sharable + static class WebSocketMessageDecoder<I> extends MessageToMessageDecoder<TextWebSocketFrame> { + + private final ObjectReader reader; + + public WebSocketMessageDecoder(Class<I> messageClass) { + reader = mapper.readerFor(messageClass); + } + + @Override + protected void decode(ChannelHandlerContext channelHandlerContext, TextWebSocketFrame frame, List<Object> list) throws Exception { + try { + try (StringReader sr = new StringReader(frame.text())) { + I i = reader.readValue(sr); + list.add(i); + } + } + catch (Exception e) { + throw new DecoderException("Unable to decode the WebSocket frame as a JSON object", e); + } + } + } + + @ChannelHandler.Sharable + static class WebSocketMessageEncoder<O> extends MessageToMessageEncoder<O> { + + private final ObjectWriter writer; + + protected WebSocketMessageEncoder(Class<O> messageClass) { + super(messageClass); + writer = mapper.writerFor(messageClass); + } + + @Override + protected void encode(ChannelHandlerContext ctx, O o, List<Object> list) throws Exception { + try { + StringWriter sw = new StringWriter(); + writer.writeValue(sw, o); + list.add(new TextWebSocketFrame(sw.toString())); + } + catch (Exception e) { + throw new EncoderException("Unable to encode the JSON object as a WebSocket frame", e); + } + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/RedirectHandler.java 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/RedirectHandler.java index 40b67767737..07a79e11902 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/RedirectHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/RedirectHandler.java @@ -49,10 +49,16 @@ * {@link SimpleChannelInboundHandler} which encapsulates the redirection logic for the * REST endpoints. * + * <p>This handler supports two modes of use. + * <ol> + * <li>Inheritance - subclasses override {@code respondAsLeader}.</li> + * <li>Composition - upstream handlers receive a user event containing the gateway, see {@link GatewayRetrieved}.</li> + * </ol> + * * @param <T> type of the leader to retrieve */ @ChannelHandler.Sharable -public abstract class RedirectHandler<T extends RestfulGateway> extends SimpleChannelInboundHandler<Routed> { +public class RedirectHandler<T extends RestfulGateway> extends SimpleChannelInboundHandler<Routed> { protected final Logger logger = LoggerFactory.getLogger(getClass()); @@ -64,7 +70,7 @@ private String localAddress; - protected RedirectHandler( + public RedirectHandler( @Nonnull CompletableFuture<String> localAddressFuture, @Nonnull GatewayRetriever<? extends T> leaderRetriever, @Nonnull Time timeout) { @@ -172,5 +178,39 @@ protected void channelRead0( } } - protected abstract void respondAsLeader(ChannelHandlerContext channelHandlerContext, Routed routed, T gateway) throws Exception; + /** + * Responds to an HTTP request in combination with the leader gateway. + * + * <p>The default behavior is to announce the leader gateway with a user event, + * and then to forward the HTTP request to the next handler. 
+ * + * @param ctx the channel handler context + * @param routed the HTTP request + * @param gateway the leader gateway + */ + protected void respondAsLeader(ChannelHandlerContext ctx, Routed routed, T gateway) throws Exception { + + // announce the gateway to upstream handlers + ctx.fireUserEventTriggered(new GatewayRetrieved<>(gateway)); + + // propagate the HTTP request + ReferenceCountUtil.retain(routed); + ctx.fireChannelRead(routed); + } + + /** + * A gateway retrieval event. + * @param <T> the gateway type + */ + public static class GatewayRetrieved<T> { + private final T gateway; + + public GatewayRetrieved(T gateway) { + this.gateway = gateway; + } + + public T getGateway() { + return gateway; + } + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/MessageParameters.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/MessageParameters.java index b19b12e7206..b91c4510e83 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/MessageParameters.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/MessageParameters.java @@ -21,9 +21,15 @@ import org.apache.flink.util.Preconditions; import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.StringJoiner; /** * This class defines the path/query {@link MessageParameter}s that can be used for a request. + * + * <p>An instance of {@link MessageParameters} is mutable. */ public abstract class MessageParameters { @@ -61,7 +67,7 @@ public final boolean isResolved() { * <p>Unresolved optional parameters will be ignored. 
* * @param genericUrl URL to resolve - * @param parameters message parameters parameters + * @param parameters message parameters * @return resolved url, e.g "/jobs/1234?state=running" * @throws IllegalStateException if any mandatory parameter was not resolved */ @@ -100,4 +106,74 @@ public static String resolveUrl(String genericUrl, MessageParameters parameters) return path.toString(); } + + /** + * Resolves the message parameters using the given path and query parameter values. + * + * <p>This method updates the state of the parameters of this message. + * + * @param receivedPathParameters the received path parameters. + * @param receivedQueryParameters the received query parameters. + * @throws IllegalArgumentException if a parameter value cannot be processed. + */ + public void resolveParameters(Map<String, String> receivedPathParameters, Map<String, List<String>> receivedQueryParameters) { + for (MessagePathParameter<?> pathParameter : getPathParameters()) { + String value = receivedPathParameters.get(pathParameter.getKey()); + if (value != null) { + try { + pathParameter.resolveFromString(value); + } catch (Exception e) { + throw new IllegalArgumentException("Cannot resolve path parameter (" + pathParameter.getKey() + ") from value \"" + value + "\"."); + } + } + } + + for (MessageQueryParameter<?> queryParameter : getQueryParameters()) { + List<String> values = receivedQueryParameters.get(queryParameter.getKey()); + if (values != null && !values.isEmpty()) { + StringJoiner joiner = new StringJoiner(","); + values.forEach(joiner::add); + + try { + queryParameter.resolveFromString(joiner.toString()); + } catch (Exception e) { + throw new IllegalArgumentException("Cannot resolve query parameter (" + queryParameter.getKey() + ") from value \"" + joiner + "\"."); + } + } + } + } + + + /** + * Returns the value of the {@link MessagePathParameter} for the given class. 
+ * + * @param parameterClass class of the parameter + * @param <X> the value type that the parameter contains + * @param <PP> type of the path parameter + * @return path parameter value for the given class + * @throws IllegalStateException if no value is defined for the given parameter class + */ + @SuppressWarnings("unchecked") + public <X, PP extends MessagePathParameter<X>> X getPathParameter(Class<PP> parameterClass) { + return getPathParameters().stream() + .filter(p -> parameterClass.equals(p.getClass())) + .map(p -> ((PP) p).getValue()) + .findAny().orElseThrow(() -> new IllegalStateException("No parameter could be found for the given class.")); + } + + /** + * Returns the value of the {@link MessageQueryParameter} for the given class. + * + * @param parameterClass class of the parameter + * @param <X> the value type that the parameter contains + * @param <QP> type of the query parameter + * @return query parameter value for the given class, or an empty list if no parameter value exists for the given class + */ + @SuppressWarnings("unchecked") + public <X, QP extends MessageQueryParameter<X>> List<X> getQueryParameter(Class<QP> parameterClass) { + return getQueryParameters().stream() + .filter(p -> parameterClass.equals(p.getClass())) + .map(p -> ((QP) p).getValue()) + .findAny().orElse(Collections.emptyList()); + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/RequestBody.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/RequestBody.java index ca55b17532b..7e2ef568448 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/RequestBody.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/RequestBody.java @@ -19,7 +19,7 @@ package org.apache.flink.runtime.rest.messages; /** - * Marker interface for all requests of the REST API. This class represents the http body of a request. + * Marker interface for all requests of the REST API. 
This class represents the body of an HTTP request or WebSocket message. * * <p>Subclass instances are converted to JSON using jackson-databind. Subclasses must have a constructor that accepts * all fields of the JSON request, that should be annotated with {@code @JsonCreator}. diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/ResponseBody.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/ResponseBody.java index d4e94d1d6ab..6d183903637 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/ResponseBody.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/ResponseBody.java @@ -19,7 +19,7 @@ package org.apache.flink.runtime.rest.messages; /** - * Marker interface for all responses of the REST API. This class represents the http body of a response. + * Marker interface for all responses of the REST API. This class represents the body of an HTTP response or WebSocket message. * * <p>Subclass instances are converted to JSON using jackson-databind. Subclasses must have a constructor that accepts * all fields of the JSON response, that should be annotated with {@code @JsonCreator}. diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/WebSocketSpecification.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/WebSocketSpecification.java new file mode 100644 index 00000000000..38ca0e86743 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/WebSocketSpecification.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rest.messages; + +import org.apache.flink.runtime.rest.HttpMethodWrapper; +import org.apache.flink.runtime.rest.handler.RestHandlerSpecification; + +/** + * Extended REST handler specification with websocket information. + * + * <p>Implementations must be state-less. + * + * @param <M> message parameters type + * @param <I> inbound message type + * @param <O> outbound message type + */ +public interface WebSocketSpecification<M extends MessageParameters, I extends RequestBody, O extends ResponseBody> extends RestHandlerSpecification { + + @Override + default HttpMethodWrapper getHttpMethod() { + return HttpMethodWrapper.GET; + } + + /** + * Returns the WebSocket subprotocol associated with the REST resource. + */ + String getSubprotocol(); + + /** + * Returns the base class of client-to-server messages. + * + * @return class of the message + */ + Class<I> getClientClass(); + + /** + * Returns the base class of server-to-client messages. + * + * @return class of the message + */ + Class<O> getServerClass(); + + /** + * Returns a new {@link MessageParameters} object. 
+ * + * @return new message parameters object + */ + M getUnresolvedMessageParameters(); +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/websocket/KeyedChannelRouter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/websocket/KeyedChannelRouter.java new file mode 100644 index 00000000000..4fecafb4100 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/websocket/KeyedChannelRouter.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rest.websocket; + +import org.apache.flink.shaded.netty4.io.netty.channel.Channel; +import org.apache.flink.shaded.netty4.io.netty.channel.group.ChannelGroup; +import org.apache.flink.shaded.netty4.io.netty.channel.group.ChannelGroupFuture; +import org.apache.flink.shaded.netty4.io.netty.util.AttributeKey; + +import javax.annotation.Nonnull; + +import java.util.Objects; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * Routes messages to channels based on a routing key. + * @param <K> the key type. 
+ */ +public class KeyedChannelRouter<K> { + + private final AttributeKey<K> attributeKey; + + private final ChannelGroup channels; + + public KeyedChannelRouter(AttributeKey<K> attributeKey, ChannelGroup channelGroup) { + this.attributeKey = checkNotNull(attributeKey); + this.channels = checkNotNull(channelGroup); + } + + /** + * Registers a channel to receive messages for a given routing key. + * + * @param channel the channel to register. + */ + public void register(@Nonnull Channel channel, @Nonnull K routingKey) { + channel.attr(attributeKey).set(routingKey); + channels.add(channel); + } + + /** + * Unregisters a channel. + * + * @param channel the channel to unregister. + */ + public void unregister(@Nonnull Channel channel) { + channels.remove(channel); + } + + /** + * Writes (without flushing) an object to select channels based on a routing key. + * + * @param routingKey the key to select the target channel(s). + * @param o the object to write; call writeAndFlush or flush the channels to send it. + */ + public ChannelGroupFuture write(@Nonnull K routingKey, @Nonnull Object o) { + return channels.write(o, channel -> isMatch(channel, routingKey)); + } + + /** + * Writes and flushes an object to select channels based on a routing key. + * + * @param routingKey the key to select the target channel(s). + * @param o the object to write and flush. + */ + public ChannelGroupFuture writeAndFlush(@Nonnull K routingKey, @Nonnull Object o) { + return channels.writeAndFlush(o, channel -> isMatch(channel, routingKey)); + } + + /** + * Closes all channels associated with the given routing key.
+ */ + public ChannelGroupFuture close(@Nonnull K routingKey) { + return channels.close(channel -> isMatch(channel, routingKey)); + } + + private boolean isMatch(Channel channel, K routingKey) { + return Objects.equals(routingKey, channel.attr(attributeKey).get()); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/websocket/WebSocket.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/websocket/WebSocket.java new file mode 100644 index 00000000000..77d53ef4427 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/websocket/WebSocket.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rest.websocket; + +import org.apache.flink.runtime.rest.messages.RequestBody; +import org.apache.flink.runtime.rest.messages.ResponseBody; + +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFuture; + +/** + * A WebSocket for sending and receiving messages. + * + * @param <I> type of the server-to-client messages. + * @param <O> type of the client-to-server messages. + */ +public interface WebSocket<I extends ResponseBody, O extends RequestBody> { + + /** + * Adds a listener for websocket messages. 
+ */ + void addListener(WebSocketListener<I> listener); + + /** + * Sends a message. + */ + ChannelFuture send(O message); + + /** + * Closes the websocket. + */ + ChannelFuture close(); +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/websocket/WebSocketListener.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/websocket/WebSocketListener.java new file mode 100644 index 00000000000..33410fc4fcc --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/websocket/WebSocketListener.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rest.websocket; + +import org.apache.flink.runtime.rest.messages.ResponseBody; +import org.apache.flink.runtime.util.event.EventListener; + +/** + * A listener for WebSocket messages. + * + * @param <T> type of the server-to-client messages. 
+ */ +public interface WebSocketListener<T extends ResponseBody> extends EventListener<T> { } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/rest/RestEndpointITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/rest/RestEndpointITCase.java index be5985f3a2c..0e4218c31ff 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/rest/RestEndpointITCase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/rest/RestEndpointITCase.java @@ -23,6 +23,7 @@ import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.rest.handler.AbstractRestHandler; +import org.apache.flink.runtime.rest.handler.AbstractWebSocketHandler; import org.apache.flink.runtime.rest.handler.HandlerRequest; import org.apache.flink.runtime.rest.handler.RestHandlerException; import org.apache.flink.runtime.rest.handler.RestHandlerSpecification; @@ -33,7 +34,10 @@ import org.apache.flink.runtime.rest.messages.MessageQueryParameter; import org.apache.flink.runtime.rest.messages.RequestBody; import org.apache.flink.runtime.rest.messages.ResponseBody; +import org.apache.flink.runtime.rest.messages.WebSocketSpecification; import org.apache.flink.runtime.rest.util.RestClientException; +import org.apache.flink.runtime.rest.websocket.KeyedChannelRouter; +import org.apache.flink.runtime.rest.websocket.WebSocket; import org.apache.flink.runtime.rpc.RpcUtils; import org.apache.flink.runtime.testingUtils.TestingUtils; import org.apache.flink.runtime.webmonitor.RestfulGateway; @@ -42,8 +46,14 @@ import org.apache.flink.util.Preconditions; import org.apache.flink.util.TestLogger; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFutureListener; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandler; +import org.apache.flink.shaded.netty4.io.netty.channel.group.ChannelGroup; +import 
org.apache.flink.shaded.netty4.io.netty.channel.group.DefaultChannelGroup; import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponseStatus; +import org.apache.flink.shaded.netty4.io.netty.util.AttributeKey; +import org.apache.flink.shaded.netty4.io.netty.util.concurrent.GlobalEventExecutor; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; @@ -55,11 +65,14 @@ import javax.annotation.Nonnull; import java.net.InetSocketAddress; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; +import java.util.concurrent.LinkedBlockingQueue; import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; @@ -78,6 +91,8 @@ private RestServerEndpoint serverEndpoint; private RestClient clientEndpoint; + private TestEventProvider eventProvider; + @Before public void setup() throws Exception { Configuration config = new Configuration(); @@ -91,12 +106,22 @@ public void setup() throws Exception { GatewayRetriever<RestfulGateway> mockGatewayRetriever = mock(GatewayRetriever.class); when(mockGatewayRetriever.getNow()).thenReturn(Optional.of(mockRestfulGateway)); + // a REST operation TestHandler testHandler = new TestHandler( CompletableFuture.completedFuture(restAddress), mockGatewayRetriever, RpcUtils.INF_TIMEOUT); - serverEndpoint = new TestRestServerEndpoint(serverConfig, testHandler); + // a WebSocket operation + ChannelGroup eventChannelGroup = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE); + eventProvider = new TestEventProvider(eventChannelGroup); + TestWebSocketOperation.WsRestHandler testWebSocketHandler = new TestWebSocketOperation.WsRestHandler( + CompletableFuture.completedFuture(restAddress), + mockGatewayRetriever, + eventProvider, + RpcUtils.INF_TIMEOUT); + + serverEndpoint = 
new TestRestServerEndpoint(serverConfig, testHandler, testWebSocketHandler); clientEndpoint = new TestRestClient(clientConfig); serverEndpoint.start(); @@ -194,19 +219,74 @@ public void testBadHandlerRequest() throws Exception { } } + /** + * Tests that a web socket operation works end-to-end. + */ + @Test + public void testWebSocketEndToEnd() throws Exception { + TestWebSocketOperation.WsParameters parameters = new TestWebSocketOperation.WsParameters(); + parameters.jobIDPathParameter.resolve(PATH_JOB_ID); + + final LinkedBlockingQueue<TestMessage> messageQueue = new LinkedBlockingQueue<>(); + final InetSocketAddress serverAddress = serverEndpoint.getServerAddress(); + + // open a websocket connection with a listener that simply enqueues incoming messages + CompletableFuture<WebSocket<TestMessage, TestMessage>> response = clientEndpoint.sendWebSocketRequest( + serverAddress.getHostName(), + serverAddress.getPort(), + new TestWebSocketOperation.WsHeaders(), + parameters, + (event) -> { + try { + messageQueue.put(event); + } + catch (Exception e) { + throw new RuntimeException(e); + } + }); + + // wait for the connection to be established + WebSocket<TestMessage, TestMessage> webSocket = response.get(); + try { + // wait for the server to register the channel (happens asynchronously after handshake complete) + TestWebSocketOperation.WsRestHandler.LATCH.await(); + + // play ping-pong (server goes first) + eventProvider.writeAndFlush(PATH_JOB_ID, new TestMessage(42)); + TestMessage received = messageQueue.take(); + Assert.assertEquals(42, received.sequenceNumber); + webSocket.send(received.incr()); + received = messageQueue.take(); + Assert.assertEquals(44, received.sequenceNumber); + } + finally { + webSocket.close(); + } + } + private static class TestRestServerEndpoint extends RestServerEndpoint { private final TestHandler testHandler; + private final TestWebSocketOperation.WsRestHandler testWebSocketHandler; - TestRestServerEndpoint(RestServerEndpointConfiguration 
configuration, TestHandler testHandler) { + TestRestServerEndpoint(RestServerEndpointConfiguration configuration, TestHandler testHandler, TestWebSocketOperation.WsRestHandler testWebSocketHandler) { super(configuration); this.testHandler = Preconditions.checkNotNull(testHandler); + this.testWebSocketHandler = Preconditions.checkNotNull(testWebSocketHandler); } @Override protected Collection<Tuple2<RestHandlerSpecification, ChannelInboundHandler>> initializeHandlers(CompletableFuture<String> restAddressFuture) { - return Collections.singleton(Tuple2.of(new TestHeaders(), testHandler)); + return Arrays.asList( + Tuple2.of(new TestHeaders(), testHandler), + Tuple2.of(new TestWebSocketOperation.WsHeaders(), testWebSocketHandler)); + } + } + + private static class TestEventProvider extends KeyedChannelRouter<JobID> { + public TestEventProvider(ChannelGroup channelGroup) { + super(AttributeKey.valueOf("jobID"), channelGroup); } } @@ -325,6 +405,88 @@ public TestParameters getUnresolvedMessageParameters() { } } + private static class TestWebSocketOperation { + + private static class WsParameters extends MessageParameters { + private final JobIDPathParameter jobIDPathParameter = new JobIDPathParameter(); + + @Override + public Collection<MessagePathParameter<?>> getPathParameters() { + return Collections.singleton(jobIDPathParameter); + } + + @Override + public Collection<MessageQueryParameter<?>> getQueryParameters() { + return Collections.emptyList(); + } + } + + static class WsHeaders implements WebSocketSpecification<WsParameters, TestMessage, TestMessage> { + + @Override + public String getTargetRestEndpointURL() { + return "/test/:jobid/subscribe"; + } + + @Override + public String getSubprotocol() { + return "test"; + } + + @Override + public Class<TestMessage> getServerClass() { + return TestMessage.class; + } + + @Override + public Class<TestMessage> getClientClass() { + return TestMessage.class; + } + + @Override + public WsParameters 
getUnresolvedMessageParameters() { + return new WsParameters(); + } + } + + static class WsRestHandler extends AbstractWebSocketHandler<RestfulGateway, WsParameters, TestMessage, TestMessage> { + + public static final CountDownLatch LATCH = new CountDownLatch(1); + + private final TestEventProvider eventProvider; + + WsRestHandler( + CompletableFuture<String> localAddressFuture, + GatewayRetriever<RestfulGateway> leaderRetriever, + TestEventProvider eventProvider, + Time timeout) { + super(localAddressFuture, leaderRetriever, timeout, new WsHeaders()); + this.eventProvider = eventProvider; + } + + @Override + protected CompletableFuture<Void> handshakeInitiated(ChannelHandlerContext ctx, WsParameters parameters) throws Exception { + // validate request before completing the handshake + JobID jobID = parameters.getPathParameter(JobIDPathParameter.class); + Assert.assertEquals(PATH_JOB_ID, jobID); + return CompletableFuture.completedFuture(null); + } + + @Override + protected void handshakeCompleted(ChannelHandlerContext ctx, WsParameters parameters) throws Exception { + // handshake complete; register for server events + JobID jobID = parameters.getPathParameter(JobIDPathParameter.class); + eventProvider.register(ctx.channel(), jobID); + LATCH.countDown(); + } + + @Override + protected void messageReceived(ChannelHandlerContext ctx, WsParameters parameters, TestMessage msg) throws Exception { + ctx.channel().writeAndFlush(msg.incr()).addListener(ChannelFutureListener.CLOSE_ON_FAILURE); + } + } + } + static class JobIDPathParameter extends MessagePathParameter<JobID> { JobIDPathParameter() { super(JOB_ID_KEY); @@ -373,4 +535,17 @@ public String convertStringToValue(JobID value) { return value.toString(); } } + + static class TestMessage implements ResponseBody, RequestBody { + public final int sequenceNumber; + + @JsonCreator + public TestMessage(@JsonProperty("sequenceNumber") int sequenceNumber) { + this.sequenceNumber = sequenceNumber; + } + + public TestMessage 
incr() { + return new TestMessage(sequenceNumber + 1); + } + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/rest/handler/AbstractWebSocketHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/rest/handler/AbstractWebSocketHandlerTest.java new file mode 100644 index 00000000000..951ecfa8899 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/rest/handler/AbstractWebSocketHandlerTest.java @@ -0,0 +1,456 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.rest.handler; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.time.Time; +import org.apache.flink.runtime.rest.messages.MessageParameters; +import org.apache.flink.runtime.rest.messages.MessagePathParameter; +import org.apache.flink.runtime.rest.messages.MessageQueryParameter; +import org.apache.flink.runtime.rest.messages.RequestBody; +import org.apache.flink.runtime.rest.messages.ResponseBody; +import org.apache.flink.runtime.rest.messages.WebSocketSpecification; +import org.apache.flink.runtime.rpc.RpcUtils; +import org.apache.flink.runtime.webmonitor.RestfulGateway; +import org.apache.flink.runtime.webmonitor.retriever.GatewayRetriever; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; +import org.apache.flink.shaded.netty4.io.netty.channel.Channel; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandler; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInitializer; +import org.apache.flink.shaded.netty4.io.netty.channel.embedded.EmbeddedChannel; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.FullHttpRequest; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpHeaders; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpRequestDecoder; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponse; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponseEncoder; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponseStatus; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.router.Handler; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.router.Router; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.websocketx.WebSocketClientHandshaker13; +import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.websocketx.WebSocketVersion; +import org.apache.flink.shaded.netty4.io.netty.util.AbstractReferenceCounted; +import org.apache.flink.shaded.netty4.io.netty.util.ReferenceCountUtil; + +import com.fasterxml.jackson.annotation.JsonProperty; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.net.URI; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.LinkedBlockingQueue; + +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Tests the {@link AbstractWebSocketHandler}. + */ +public class AbstractWebSocketHandlerTest { + + private static final String TEST_SUBPROTOCOL = "test"; + private static final String REST_ADDRESS = "http://localhost:1234"; + private static final JobID TEST_JOB_ID = new JobID(); + + private GatewayRetriever<RestfulGateway> gatewayRetriever; + + @Before + public void setup() throws Exception { + // setup a mock gateway and retriever to satisfy the RedirectHandler + RestfulGateway mockRestfulGateway = mock(RestfulGateway.class); + when(mockRestfulGateway.requestRestAddress(any(Time.class))).thenReturn(CompletableFuture.completedFuture(REST_ADDRESS)); + gatewayRetriever = mock(GatewayRetriever.class); + when(gatewayRetriever.getNow()).thenReturn(Optional.of(mockRestfulGateway)); + } + + /** + * Tests that a handshake request whose mandatory query parameter has an unparseable value is rejected.
+ */ + @Test + public void testInvalidMessageParameter() { + + EmbeddedChannel channel = channel(TestSpecification.INSTANCE, new TestHandler()); + + // write a handshake request whose mandatory query parameter has an unparseable value; only "q=42" is replaced because the random job ID's hex string may itself contain "42" + TestParameters params = TestParameters.from(TEST_JOB_ID, 42); + URI requestUri = URI.create(requestUri(TestSpecification.INSTANCE, params).toString() + .replace("q=42", "q=NA")); + FullHttpRequest request = handshakeRequest(requestUri, TEST_SUBPROTOCOL); + channel.writeInbound(request); + + // verify that the request was rejected + HttpResponse response = (HttpResponse) channel.readOutbound(); + Assert.assertEquals(0, request.refCnt()); + Assert.assertEquals(HttpResponseStatus.BAD_REQUEST, response.getStatus()); + Assert.assertTrue(channel.isOpen()); + } + + /** + * Tests that a handshake request lacking a mandatory query parameter is rejected. + */ + @Test + public void testMissingMessageParameter() { + + EmbeddedChannel channel = channel(TestSpecification.INSTANCE, new TestHandler()); + + // write a handshake request that is lacking a mandatory parameter (the query key "q" is renamed) + TestParameters params = TestParameters.from(TEST_JOB_ID, 42); + URI requestUri = URI.create(requestUri(TestSpecification.INSTANCE, params).toString() + .replace("?q=", "?_=")); + FullHttpRequest request = handshakeRequest(requestUri, TEST_SUBPROTOCOL); + channel.writeInbound(request); + + // verify that the request was rejected + HttpResponse response = (HttpResponse) channel.readOutbound(); + Assert.assertEquals(0, request.refCnt()); + Assert.assertEquals(HttpResponseStatus.BAD_REQUEST, response.getStatus()); + Assert.assertTrue(channel.isOpen()); + } + + /** + * Tests that a handshake request carrying all mandatory parameters is accepted and that the parsed parameters reach the handler.
+ */ + @Test + public void testValidMessageParameters() { + + TestHandler handler = new TestHandler(); + EmbeddedChannel channel = channel(TestSpecification.INSTANCE, handler); + + // write a handshake request with the mandatory parameter + TestParameters params = TestParameters.from(TEST_JOB_ID, 42); + FullHttpRequest request = handshakeRequest(requestUri(TestSpecification.INSTANCE, params), TEST_SUBPROTOCOL); + channel.writeInbound(request); + + // verify that the request was accepted and that the message parameters were passed to the handshakeInitiated callback + HttpResponse response = (HttpResponse) channel.readOutbound(); + Assert.assertEquals(0, request.refCnt()); + Assert.assertEquals(HttpResponseStatus.SWITCHING_PROTOCOLS, response.getStatus()); + Assert.assertNotNull(handler.handshakeInitiated); + Assert.assertEquals(Collections.singletonList(42), handler.handshakeInitiated.queryParameter.getValue()); + } + + /** + * Tests the websocket upgrade logic. + */ + @Test + public void testWebSocketUpgrade() { + + EmbeddedChannel channel = channel(TestSpecification.INSTANCE, new TestHandler()); + + // write a handshake request + TestParameters params = TestParameters.from(TEST_JOB_ID, 42); + FullHttpRequest request = handshakeRequest(requestUri(TestSpecification.INSTANCE, params), TEST_SUBPROTOCOL); + channel.writeInbound(request); + + // check for a websocket handshaker and message handler in the pipeline + Assert.assertNotNull(channel.pipeline().get(WebSocketServerProtocolHandler.class)); + Assert.assertNotNull(channel.pipeline().get(JsonWebSocketMessageCodec.class)); + + // check for a websocket handshake response + HttpResponse response = (HttpResponse) channel.readOutbound(); + Assert.assertEquals(0, request.refCnt()); + Assert.assertNotNull(response); + Assert.assertEquals(HttpResponseStatus.SWITCHING_PROTOCOLS, response.getStatus()); + } + + /** + * Tests upgrade failure due to a handshaking issue.
+ */ + @Test + public void testWebSocketUpgradeFailure() { + + EmbeddedChannel channel = channel(TestSpecification.INSTANCE, new TestHandler()); + + // write an invalid handshake request + TestParameters params = TestParameters.from(TEST_JOB_ID, 42); + FullHttpRequest request = handshakeRequest(requestUri(TestSpecification.INSTANCE, params), TEST_SUBPROTOCOL); + request.headers().remove("Sec-WebSocket-Key"); + channel.writeInbound(request); + + // check for a websocket handshake response + HttpResponse response = (HttpResponse) channel.readOutbound(); + Assert.assertEquals(0, request.refCnt()); + Assert.assertNotNull(response); + Assert.assertEquals(HttpResponseStatus.BAD_REQUEST, response.getStatus()); + Assert.assertFalse(channel.isOpen()); + } + + /** + * Tests message reading and writing. + */ + @Test + public void testMessageReadWrite() throws Exception { + TestHandler handler = new TestHandler(); + EmbeddedChannel channel = channel(TestSpecification.INSTANCE, handler); + + TestParameters params = TestParameters.from(TEST_JOB_ID, 42); + FullHttpRequest request = handshakeRequest(requestUri(TestSpecification.INSTANCE, params), TEST_SUBPROTOCOL); + channel.writeInbound(request); + HttpResponse response = (HttpResponse) channel.readOutbound(); + Assert.assertEquals(HttpResponseStatus.SWITCHING_PROTOCOLS, response.getStatus()); + + ClientMessage expected = new ClientMessage(42); + channel.writeInbound(expected); + ClientMessage actual = handler.messages.poll(); // EmbeddedChannel dispatches synchronously, so the message is already queued; poll() fails fast instead of blocking forever if it is missing + Assert.assertEquals(expected.sequenceNumber, actual.sequenceNumber); + Assert.assertEquals(1, expected.refCnt()); + Assert.assertTrue(channel.isOpen()); + + ByteBuf msg = (ByteBuf) channel.readOutbound(); + Assert.assertEquals(23, msg.readableBytes()); // contains an on-the-wire text frame + ReferenceCountUtil.release(msg); + } + + /** + * A path parameter for test purposes.
+ */ + static class TestPathParameter extends MessagePathParameter<JobID> { + TestPathParameter() { + super("jobid"); + } + + @Override + public JobID convertFromString(String value) { + return JobID.fromHexString(value); + } + + @Override + protected String convertToString(JobID value) { + return value.toString(); + } + } + + /** + * A query parameter for test purposes. + */ + private static class TestQueryParameter extends MessageQueryParameter<Integer> { + + TestQueryParameter() { + super("q", MessageParameterRequisiteness.MANDATORY); + } + + @Override + public Integer convertValueFromString(String value) { + return Integer.parseInt(value); + } + + @Override + public String convertStringToValue(Integer value) { + return value.toString(); + } + } + + /** + * Parameters for the test WebSocket resource. + */ + private static class TestParameters extends MessageParameters { + private final TestPathParameter pathParameter = new TestPathParameter(); + private final TestQueryParameter queryParameter = new TestQueryParameter(); + + @Override + public Collection<MessagePathParameter<?>> getPathParameters() { + return Collections.singleton(pathParameter); + } + + @Override + public Collection<MessageQueryParameter<?>> getQueryParameters() { + return Collections.singleton(queryParameter); + } + + public static TestParameters from(JobID jobID, int sequenceNumber) { + TestParameters p = new TestParameters(); + p.pathParameter.resolve(jobID); + p.queryParameter.resolve(Collections.singletonList(sequenceNumber)); + return p; + } + } + + /** + * The specification for the test WebSocket resource.
+ */ + static class TestSpecification implements WebSocketSpecification<TestParameters, ClientMessage, ServerMessage> { + + static final TestSpecification INSTANCE = new TestSpecification(); + + @Override + public String getTargetRestEndpointURL() { + return "/jobs/:jobid/subscribe"; + } + + @Override + public String getSubprotocol() { + return TEST_SUBPROTOCOL; + } + + @Override + public Class<ServerMessage> getServerClass() { + return ServerMessage.class; + } + + @Override + public Class<ClientMessage> getClientClass() { + return ClientMessage.class; + } + + @Override + public TestParameters getUnresolvedMessageParameters() { + return new TestParameters(); + } + } + + /** + * The channel handler for the test WebSocket resource. + */ + private class TestHandler extends AbstractWebSocketHandler<RestfulGateway, TestParameters, ClientMessage, ServerMessage> { + + TestParameters handshakeInitiated = null; + + TestParameters handshakeCompleted = null; + + final LinkedBlockingQueue<ClientMessage> messages = new LinkedBlockingQueue<>(); + + public TestHandler() { + super(CompletableFuture.completedFuture(REST_ADDRESS), gatewayRetriever, RpcUtils.INF_TIMEOUT, TestSpecification.INSTANCE); + } + + @Override + protected CompletableFuture<Void> handshakeInitiated(ChannelHandlerContext ctx, TestParameters parameters) throws Exception { + handshakeInitiated = parameters; + return CompletableFuture.completedFuture(null); + } + + @Override + protected void handshakeCompleted(ChannelHandlerContext ctx, TestParameters parameters) throws Exception { + handshakeCompleted = parameters; + } + + @Override + protected void messageReceived(ChannelHandlerContext ctx, TestParameters parameters, ClientMessage msg) throws Exception { + ReferenceCountUtil.retain(msg); + messages.put(msg); + + ctx.channel().writeAndFlush(new ServerMessage(msg.sequenceNumber)); + } + } + + /** + * A WebSocket message for test purposes.
+ */ + private static class ClientMessage extends AbstractReferenceCounted implements RequestBody { + public final int sequenceNumber; + + public ClientMessage(@JsonProperty("sequenceNumber") int sequenceNumber) { + this.sequenceNumber = sequenceNumber; + } + + @Override + protected void deallocate() { + } + } + + /** + * A WebSocket message for test purposes. + */ + private static class ServerMessage extends AbstractReferenceCounted implements ResponseBody { + public final int sequenceNumber; + + public ServerMessage(@JsonProperty("sequenceNumber") int sequenceNumber) { + this.sequenceNumber = sequenceNumber; + } + + @Override + protected void deallocate() { + } + } + + // ----- Utility ----- + + /** + * Creates an embedded channel for HTTP/websocket test purposes. + */ + private EmbeddedChannel channel(RestHandlerSpecification spec, ChannelInboundHandler handler) { + + // the websocket handshaker looks for HTTP decoder/encoder (to know where to inject WS handlers). + // For test purposes, use a passthru encoder so that tests may assert various aspects of the response. + return new EmbeddedChannel(new ChannelInitializer<Channel>() { + @Override + protected void initChannel(Channel channel) throws Exception { + Router router = new Router(); + router.GET(spec.getTargetRestEndpointURL(), handler); + Handler routerHandler = new Handler(router); + channel.pipeline() + .addLast(new HttpRequestDecoder(), new MockHttpResponseEncoder()) + .addLast(routerHandler.name(), routerHandler); + } + }); + } + + /** + * Creates a URI for the given REST spec. + */ + private static URI requestUri(RestHandlerSpecification spec, MessageParameters parameters) { + URI resolved = URI.create(MessageParameters.resolveUrl(spec.getTargetRestEndpointURL(), parameters)); + return URI.create(REST_ADDRESS).resolve(resolved); + } + + /** + * Creates a handshake request for the given request URI.
+ */ + private static FullHttpRequest handshakeRequest(URI requestUri, String subprotocol) { + URI wsURL = URI.create(requestUri.toString().replace("http:", "ws:").replace("https:", "wss:")); + MockHandshaker handshaker = new MockHandshaker(wsURL, subprotocol, HttpHeaders.EMPTY_HEADERS); + FullHttpRequest httpRequest = handshaker.newHandshakeRequest(); + return httpRequest; + } + + /** + * A helper class for using Netty's built-in handshaking class to construct valid handshake requests. + */ + private static class MockHandshaker extends WebSocketClientHandshaker13 { + public MockHandshaker(URI webSocketURL, String subprotocol, HttpHeaders customHeaders) { + super(webSocketURL, WebSocketVersion.V13, subprotocol, false, customHeaders, Integer.MAX_VALUE); + } + + @Override + public FullHttpRequest newHandshakeRequest() { + return super.newHandshakeRequest(); + } + } + + /** + * A pass-through encoder (retains and forwards each message without encoding it) that must exist to satisfy the WS handshaker. + */ + private static class MockHttpResponseEncoder extends HttpResponseEncoder { + @Override + protected void encode(ChannelHandlerContext ctx, Object msg, List<Object> out) throws Exception { + if (msg != null) { + ReferenceCountUtil.retain(msg); + out.add(msg); + } + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/rest/handler/JsonWebSocketMessageCodecTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/rest/handler/JsonWebSocketMessageCodecTest.java new file mode 100644 index 00000000000..48283f34bb5 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/rest/handler/JsonWebSocketMessageCodecTest.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rest.handler; + +import org.apache.flink.shaded.netty4.io.netty.channel.embedded.EmbeddedChannel; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.DecoderException; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.EncoderException; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.websocketx.TextWebSocketFrame; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.junit.Assert; +import org.junit.Test; + +/** + * Tests the {@link JsonWebSocketMessageCodec}. 
+ */ +public class JsonWebSocketMessageCodecTest { // checks that JsonWebSocketMessageCodec decodes inbound TextWebSocketFrames into TestMessage and encodes outbound TestMessages into TextWebSocketFrames, and that failures surface as DecoderException / EncoderException + + @Test + public void testDecodeSuccess() { + EmbeddedChannel channel = new EmbeddedChannel(new JsonWebSocketMessageCodec<>(TestMessage.class, TestMessage.class)); + channel.writeInbound(new TextWebSocketFrame("{\"sequenceNumber\":42}")); + TestMessage actual = (TestMessage) channel.readInbound(); + Assert.assertNotNull(actual); + Assert.assertEquals(42, actual.sequenceNumber); + } + + @Test(expected = DecoderException.class) + public void testDecodeFailure() { + EmbeddedChannel channel = new EmbeddedChannel(new JsonWebSocketMessageCodec<>(TestMessage.class, TestMessage.class)); + channel.writeInbound(new TextWebSocketFrame("")); // an empty frame is not valid JSON, so decoding is expected to fail + } + + @Test + public void testEncodeSuccess() { + EmbeddedChannel channel = new EmbeddedChannel(new JsonWebSocketMessageCodec<>(TestMessage.class, TestMessage.class)); + channel.writeOutbound(new TestMessage(42)); + TextWebSocketFrame actual = (TextWebSocketFrame) channel.readOutbound(); + Assert.assertNotNull(actual); + Assert.assertEquals("{\"sequenceNumber\":42}", actual.text()); + } + + @Test(expected = EncoderException.class) + public void testEncodeFailure() { + EmbeddedChannel channel = new EmbeddedChannel(new JsonWebSocketMessageCodec<>(TestMessage.class, TestMessage.class)); + channel.writeOutbound(new TestMessage(-1)); // getSequenceNumber() throws for negative values, which is expected to make serialization fail + } + + private static class TestMessage { // built through the @JsonCreator constructor; the private field is exposed only via getSequenceNumber() + + private final int sequenceNumber; + + public int getSequenceNumber() { + if (sequenceNumber < 0) { // deliberate failure hook used by testEncodeFailure + throw new IllegalStateException(); + } + return sequenceNumber; + } + + @JsonCreator + public TestMessage(@JsonProperty("sequenceNumber") int sequenceNumber) { + this.sequenceNumber = sequenceNumber; + } + } +} ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at: us...@infra.apache.org > Create WebSocket handler (server) > --------------------------------- > > Key: FLINK-7738 > URL: https://issues.apache.org/jira/browse/FLINK-7738 > Project: Flink > Issue Type: Sub-task > Components: Cluster Management, Mesos > Reporter: Eron Wright > Assignee: Eron Wright > Priority: Major > Labels: pull-request-available > > An abstract handler is needed to support websocket communication. -- This message was sent by Atlassian JIRA (v7.6.3#76005)