[ https://issues.apache.org/jira/browse/FLINK-7738?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16556257#comment-16556257 ]

ASF GitHub Bot commented on FLINK-7738:
---------------------------------------

EronWright closed pull request #4767: [FLINK-7738] [flip-6] Create WebSocket handler (server, client)
URL: https://github.com/apache/flink/pull/4767
 
 
   

This pull request was merged from a forked repository. Because GitHub hides
the original diff of a merged pull request from a fork, the diff is
reproduced below for the sake of provenance:

diff --git 
a/flink-runtime-web/src/test/java/org/apache/flink/runtime/webmonitor/RedirectHandlerTest.java
 
b/flink-runtime-web/src/test/java/org/apache/flink/runtime/webmonitor/RedirectHandlerTest.java
index 4808781c7b8..cf465294f5a 100644
--- 
a/flink-runtime-web/src/test/java/org/apache/flink/runtime/webmonitor/RedirectHandlerTest.java
+++ 
b/flink-runtime-web/src/test/java/org/apache/flink/runtime/webmonitor/RedirectHandlerTest.java
@@ -29,17 +29,24 @@
 import org.apache.flink.util.TestLogger;
 
 import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext;
+import 
org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandlerAdapter;
+import 
org.apache.flink.shaded.netty4.io.netty.channel.embedded.EmbeddedChannel;
+import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.DefaultFullHttpRequest;
+import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpMethod;
 import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponse;
 import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponseStatus;
+import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpVersion;
 import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.router.KeepAliveWrite;
 import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.router.Routed;
 import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.router.Router;
+import org.apache.flink.shaded.netty4.io.netty.util.ReferenceCountUtil;
 
 import org.junit.Assert;
 import org.junit.Test;
 
 import javax.annotation.Nonnull;
 
+import java.util.HashMap;
 import java.util.Optional;
 import java.util.concurrent.CompletableFuture;
 
@@ -137,6 +144,35 @@ public void testRedirectHandler() throws Exception {
                }
        }
 
+       /**
+        * Tests the approach of using the redirect handler as a standalone handler, i.e. placed
+        * directly in a pipeline in front of another handler instead of being subclassed.
+        *
+        * <p>A {@link Routed} GET request for the root path is written into an {@link EmbeddedChannel}
+        * whose pipeline is [RedirectHandler, UserEventHandler]. The test asserts that the downstream
+        * handler observed both a {@code RedirectHandler.GatewayRetrieved} user event carrying a gateway
+        * and the routed request itself.
+        *
+        * <p>The local address future and the REST address returned by the mocked gateway are both
+        * {@code correctAddress}; presumably this makes the redirect handler treat the local endpoint as
+        * the leader, so the request is forwarded instead of redirected (not visible in this test).
+        */
+       @Test
+       public void testUserEvent() {
+               final String correctAddress = "foobar:21345";
+               final CompletableFuture<String> localAddressFuture = 
CompletableFuture.completedFuture(correctAddress);
+               final Time timeout = Time.seconds(10L);
+
+               final RestfulGateway localGateway = mock(RestfulGateway.class);
+               
when(localGateway.requestRestAddress(any(Time.class))).thenReturn(CompletableFuture.completedFuture(correctAddress));
+               final GatewayRetriever<RestfulGateway> gatewayRetriever = 
mock(GatewayRetriever.class);
+               
when(gatewayRetriever.getNow()).thenReturn(Optional.of(localGateway));
+
+               final RedirectHandler<RestfulGateway> redirectHandler = new 
RedirectHandler<>(
+                       localAddressFuture,
+                       gatewayRetriever,
+                       timeout);
+               final UserEventHandler eventHandler = new UserEventHandler(); // NOTE(review): raw type; UserEventHandler<RestfulGateway> would avoid the raw-type warning
+               EmbeddedChannel channel = new EmbeddedChannel(redirectHandler, 
eventHandler);
+
+               // write a (routed) HTTP request, then validate that a user 
event was propagated
+               DefaultFullHttpRequest request = new 
DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
+               Routed routed = new Routed(null, false, request, "/", new 
HashMap<>(), new HashMap<>());
+               channel.writeInbound(routed); // all futures above are already completed; the assertions below assume processing has finished when this returns
+               Assert.assertNotNull(eventHandler.gateway);
+               Assert.assertNotNull(eventHandler.routed);
+       }
+
        private static class TestingHandler extends 
RedirectHandler<RestfulGateway> {
 
                protected TestingHandler(
@@ -154,4 +190,25 @@ protected void respondAsLeader(ChannelHandlerContext 
channelHandlerContext, Rout
                }
        }
 
+       private static class UserEventHandler<T extends RestfulGateway> extends 
ChannelInboundHandlerAdapter { // test sink: records the gateway of a RedirectHandler.GatewayRetrieved event and the last inbound message
+
+               public volatile T gateway; // set when a GatewayRetrieved user event arrives
+
+               public volatile Routed routed; // last inbound message, cast to Routed
+
+               @Override
+               @SuppressWarnings("unchecked")
+               public void userEventTriggered(ChannelHandlerContext ctx, 
Object evt) throws Exception {
+                       if (evt instanceof RedirectHandler.GatewayRetrieved) {
+                               gateway = 
((RedirectHandler.GatewayRetrieved<T>) evt).getGateway(); // unchecked: the gateway carried by the event is assumed to be a T
+                       }
+                       super.userEventTriggered(ctx, evt); // every event keeps propagating, recorded or not
+               }
+
+               @Override
+               public void channelRead(ChannelHandlerContext ctx, Object msg) 
throws Exception {
+                       routed = (Routed) msg; // NOTE(review): any non-Routed message causes a ClassCastException
+                       ReferenceCountUtil.release(msg); // the message is not forwarded, so it is released here; routed is only checked for non-null afterwards
+               }
+       }
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestClient.java 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestClient.java
index 3fcb85c5096..9c1de986ba0 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestClient.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestClient.java
@@ -20,6 +20,7 @@
 
 import org.apache.flink.api.common.time.Time;
 import org.apache.flink.configuration.ConfigConstants;
+import org.apache.flink.runtime.rest.handler.JsonWebSocketMessageCodec;
 import org.apache.flink.runtime.rest.handler.PipelineErrorHandler;
 import org.apache.flink.runtime.rest.messages.EmptyMessageParameters;
 import org.apache.flink.runtime.rest.messages.EmptyRequestBody;
@@ -28,9 +29,12 @@
 import org.apache.flink.runtime.rest.messages.MessageParameters;
 import org.apache.flink.runtime.rest.messages.RequestBody;
 import org.apache.flink.runtime.rest.messages.ResponseBody;
+import org.apache.flink.runtime.rest.messages.WebSocketSpecification;
 import org.apache.flink.runtime.rest.util.RestClientException;
 import org.apache.flink.runtime.rest.util.RestConstants;
 import org.apache.flink.runtime.rest.util.RestMapperUtils;
+import org.apache.flink.runtime.rest.websocket.WebSocket;
+import org.apache.flink.runtime.rest.websocket.WebSocketListener;
 import org.apache.flink.util.FlinkRuntimeException;
 import org.apache.flink.util.Preconditions;
 
@@ -38,6 +42,7 @@
 import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
 import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBufInputStream;
 import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled;
+import org.apache.flink.shaded.netty4.io.netty.channel.Channel;
 import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFuture;
 import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext;
 import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInitializer;
@@ -46,6 +51,7 @@
 import org.apache.flink.shaded.netty4.io.netty.channel.socket.SocketChannel;
 import 
org.apache.flink.shaded.netty4.io.netty.channel.socket.nio.NioSocketChannel;
 import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.DefaultFullHttpRequest;
+import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.DefaultHttpHeaders;
 import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.FullHttpRequest;
 import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.FullHttpResponse;
 import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpClientCodec;
@@ -54,7 +60,10 @@
 import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponse;
 import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponseStatus;
 import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpVersion;
+import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
+import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.websocketx.WebSocketVersion;
 import org.apache.flink.shaded.netty4.io.netty.handler.ssl.SslHandler;
+import org.apache.flink.shaded.netty4.io.netty.util.ReferenceCountUtil;
 import 
org.apache.flink.shaded.netty4.io.netty.util.concurrent.DefaultThreadFactory;
 
 import com.fasterxml.jackson.core.JsonParseException;
@@ -69,7 +78,11 @@
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.StringWriter;
+import java.net.URI;
+import java.util.Arrays;
+import java.util.List;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CopyOnWriteArrayList;
 import java.util.concurrent.Executor;
 import java.util.concurrent.TimeUnit;
 
@@ -81,38 +94,23 @@
 
        private static final ObjectMapper objectMapper = 
RestMapperUtils.getStrictObjectMapper();
 
+       private final RestClientConfiguration configuration; // read by ClientBootstrap whenever a new channel is initialized
+
        // used to open connections to a rest server endpoint
        private final Executor executor;
 
        private Bootstrap bootstrap;
 
        public RestClient(RestClientConfiguration configuration, Executor 
executor) {
-               Preconditions.checkNotNull(configuration);
+               this.configuration = Preconditions.checkNotNull(configuration); // retained instead of resolving the SSL engine once up front
                this.executor = Preconditions.checkNotNull(executor);
 
-               SSLEngine sslEngine = configuration.getSslEngine();
-               ChannelInitializer<SocketChannel> initializer = new 
ChannelInitializer<SocketChannel>() {
-                       @Override
-                       protected void initChannel(SocketChannel socketChannel) 
throws Exception {
-                               // SSL should be the first handler in the 
pipeline
-                               if (sslEngine != null) {
-                                       socketChannel.pipeline().addLast("ssl", 
new SslHandler(sslEngine));
-                               }
-
-                               socketChannel.pipeline()
-                                       .addLast(new HttpClientCodec())
-                                       .addLast(new HttpObjectAggregator(1024 
* 1024))
-                                       .addLast(new ClientHandler())
-                                       .addLast(new PipelineErrorHandler(LOG));
-                       }
-               };
                NioEventLoopGroup group = new NioEventLoopGroup(1, new 
DefaultThreadFactory("flink-rest-client-netty"));
 
                bootstrap = new Bootstrap();
                bootstrap
                        .group(group)
-                       .channel(NioSocketChannel.class)
-                       .handler(initializer);
+                       .channel(NioSocketChannel.class); // template only: each request clones this bootstrap and sets its own ClientBootstrap subclass as handler
 
                LOG.info("Rest client endpoint started.");
        }
@@ -175,7 +173,17 @@ public void shutdown(Time timeout) {
        }
 
        private <P extends ResponseBody> CompletableFuture<P> 
submitRequest(String targetAddress, int targetPort, FullHttpRequest 
httpRequest, Class<P> responseClass) {
-               return CompletableFuture.supplyAsync(() -> 
bootstrap.connect(targetAddress, targetPort), executor)
+               Bootstrap bootstrap1 = bootstrap.clone().handler(new 
ClientBootstrap() {
+                       @Override
+                       protected void initChannel(SocketChannel channel) 
throws Exception {
+                               super.initChannel(channel);
+                               channel.pipeline()
+                                       .addLast(new ClientHandler())
+                                       .addLast(new PipelineErrorHandler(LOG));
+                       }
+               });
+
+               return CompletableFuture.supplyAsync(() -> 
bootstrap1.connect(targetAddress, targetPort), executor)
                        .thenApply((channel) -> {
                                try {
                                        return channel.sync();
@@ -221,6 +229,21 @@ public void shutdown(Time timeout) {
                return responseFuture;
        }
 
+       private class ClientBootstrap extends ChannelInitializer<SocketChannel> 
{ // base initializer for every client channel (optional TLS, HTTP codec, aggregator); subclasses append request-specific handlers after calling super.initChannel. Non-static because it reads the configuration of the enclosing client.
+               @Override
+               protected void initChannel(SocketChannel channel) throws 
Exception {
+                       SSLEngine sslEngine = configuration.getSslEngine(); // NOTE(review): evaluated once per new channel; assumes each call yields an engine usable for this connection, since SSLEngine state is per connection -- confirm
+
+                       // SSL should be the first handler in the pipeline
+                       if (sslEngine != null) {
+                               channel.pipeline().addLast("ssl", new 
SslHandler(sslEngine));
+                       }
+                       channel.pipeline()
+                               .addLast(new HttpClientCodec())
+                               .addLast(new HttpObjectAggregator(1024 * 1024)); // aggregates HTTP messages with up to 1 MiB of content
+               }
+       }
+
        private static class ClientHandler extends 
SimpleChannelInboundHandler<Object> {
 
                private final CompletableFuture<JsonResponse> jsonFuture = new 
CompletableFuture<>();
@@ -300,4 +323,116 @@ public HttpResponseStatus getHttpResponseStatus() {
                        return httpResponseStatus;
                }
        }
+
+       /**
+        * Opens a WebSocket connection to the given host and port for the endpoint described by
+        * {@code spec}, with the endpoint URL resolved from {@code messageParameters}.
+        *
+        * <p>A new channel is created from a clone of the client bootstrap. Its pipeline contains the
+        * {@code ClientBootstrap} handlers (optional TLS, HTTP codec, aggregator) followed by a
+        * {@link WebSocketClientProtocolHandler} (protocol version 13, subprotocol from {@code spec},
+        * Content-Type header set to {@link RestConstants#REST_CONTENT_TYPE}), a
+        * {@link JsonWebSocketMessageCodec} and a {@code WsResponseHandler}. The connection is opened on
+        * the client executor, which blocks in {@code channel.sync()} until the connect attempt finishes.
+        *
+        * <p>NOTE(review): the handshake URL always uses the ws scheme, even when an SSL engine is
+        * configured; confirm that this is intended.
+        *
+        * <p>NOTE(review): no timeout is applied. If the channel closes before the handshake completes
+        * and no exception reaches the response handler, the returned future may never complete.
+        *
+        * @param targetAddress host to connect to
+        * @param targetPort port to connect to; the check accepts 0 to 65535, although its error
+        *                   message names the range (0, 65536] (NOTE(review): message and check disagree)
+        * @param spec WebSocket specification: endpoint URL, subprotocol, and server/client message classes
+        * @param messageParameters parameters used to resolve the endpoint URL; must already be resolved
+        * @param listeners listeners registered before the handshake; each receives every inbound message
+        * @return a future completed with the {@link WebSocket} once the client handshake completes, or
+        *         completed exceptionally if connecting fails or an exception reaches the response
+        *         handler before that
+        * @throws IOException declared, but not thrown by the body of this method
+        */
+       @SuppressWarnings("unchecked")
+       public <M extends WebSocketSpecification<U, O, I>, U extends 
MessageParameters, I extends ResponseBody, O extends RequestBody> 
CompletableFuture<WebSocket<I, O>> sendWebSocketRequest(String targetAddress, 
int targetPort, M spec, U messageParameters, WebSocketListener<I>... listeners) 
throws IOException {
+               Preconditions.checkNotNull(targetAddress);
+               Preconditions.checkArgument(0 <= targetPort && targetPort < 
65536, "The target port " + targetPort + " is not in the range (0, 65536]."); // NOTE(review): the message disagrees with the check, which accepts [0, 65535]
+               Preconditions.checkNotNull(spec);
+               Preconditions.checkNotNull(messageParameters);
+               Preconditions.checkState(messageParameters.isResolved(), 
"Message parameters were not resolved.");
+
+               String targetUrl = 
MessageParameters.resolveUrl(spec.getTargetRestEndpointURL(), 
messageParameters);
+               URI webSocketURL = URI.create("ws://" + targetAddress + ":" + 
targetPort).resolve(targetUrl);
+               LOG.debug("Sending WebSocket request to {}", webSocketURL);
+
+               final HttpHeaders headers = new DefaultHttpHeaders()
+                       .add(HttpHeaders.Names.CONTENT_TYPE, 
RestConstants.REST_CONTENT_TYPE);
+
+               Bootstrap bootstrap1 = bootstrap.clone().handler(new 
ClientBootstrap() {
+                       @Override
+                       protected void initChannel(SocketChannel channel) 
throws Exception {
+                               super.initChannel(channel);
+                               channel.pipeline()
+                                       .addLast(new 
WebSocketClientProtocolHandler(webSocketURL, WebSocketVersion.V13, 
spec.getSubprotocol(), false, headers, 65535)) // allowExtensions = false, custom handshake headers, maxFramePayloadLength = 65535
+                                       .addLast(new 
JsonWebSocketMessageCodec<>(spec.getServerClass(), spec.getClientClass())) // presumably (inbound, outbound); the server side passes (client class, server class)
+                                       .addLast(new 
WsResponseHandler<>(channel, spec.getServerClass(), spec.getClientClass(), 
listeners));
+                       }
+               });
+
+               return CompletableFuture.supplyAsync(() -> 
bootstrap1.connect(targetAddress, targetPort), executor)
+                       .thenApply((channel) -> {
+                               try {
+                                       return channel.sync();
+                               } catch (InterruptedException e) { // NOTE(review): the interrupt flag of the executor thread is not restored
+                                       throw new FlinkRuntimeException(e);
+                               }
+                       })
+                       .thenApply((ChannelFuture::channel))
+                       .thenCompose(channel -> {
+                               WsResponseHandler<I, O> handler = 
channel.pipeline().get(WsResponseHandler.class); // raw-class lookup of the handler installed above; unchecked assignment
+                               return handler.getWebSocketFuture();
+                       });
+       }
+       /**
+        * Tail handler of a client WebSocket pipeline that also serves as the {@link WebSocket} handle
+        * returned to callers.
+        *
+        * <p>It completes {@link #getWebSocketFuture()} with itself when the client handshake completes,
+        * fails that future when an exception reaches this handler, and forwards every inbound message
+        * of the inbound class to the registered listeners. Outbound messages are written directly to
+        * the channel passed to the constructor.
+        *
+        * @param <I> inbound message type (the server class of the specification)
+        * @param <O> outbound message type (the client class of the specification)
+        */
+       private static class WsResponseHandler<I extends ResponseBody, O 
extends RequestBody> extends SimpleChannelInboundHandler<I> implements 
WebSocket<I, O> {
+
+               private final Channel channel;
+               private final List<WebSocketListener<I>> listeners = new 
CopyOnWriteArrayList<>(); // copy-on-write, so addListener may run while channelRead0 iterates
+
+               private final CompletableFuture<WebSocket<I, O>> 
webSocketFuture = new CompletableFuture<>();
+               /**
+                * Returns the future that is completed with this handler once the client handshake
+                * completes, or completed exceptionally by {@link #exceptionCaught}.
+                */
+               CompletableFuture<WebSocket<I, O>> getWebSocketFuture() {
+                       return webSocketFuture;
+               }
+               /**
+                * Creates the handler.
+                *
+                * @param channel the channel this handler is installed on; used by {@link #send} and {@link #close}
+                * @param inboundClass type of the messages delivered to {@link #channelRead0}
+                * @param outboundClass currently unused
+                * @param listeners initial listeners; a null array causes a NullPointerException
+                */
+               public WsResponseHandler(Channel channel, Class<I> 
inboundClass, Class<O> outboundClass, WebSocketListener<I>[] listeners) {
+                       super(inboundClass);
+                       this.channel = channel;
+                       this.listeners.addAll(Arrays.asList(listeners));
+               }
+               /**
+                * Logs the exception and fails the handshake future.
+                *
+                * <p>NOTE(review): the channel is not closed here. Once the future is already complete,
+                * completing it exceptionally has no effect, so errors after the handshake are only logged.
+                */
+               @Override
+               public void exceptionCaught(ChannelHandlerContext ctx, 
Throwable cause) throws Exception {
+                       LOG.warn("WebSocket exception caught", cause);
+                       webSocketFuture.completeExceptionally(cause);
+               }
+               /**
+                * Completes the handshake future on {@code HANDSHAKE_COMPLETE} and logs
+                * {@code HANDSHAKE_ISSUED}. Client handshake events are consumed here; all other user
+                * events are propagated.
+                */
+               @Override
+               public void userEventTriggered(ChannelHandlerContext ctx, 
Object evt) throws Exception {
+                       if (evt instanceof 
WebSocketClientProtocolHandler.ClientHandshakeStateEvent) {
+                               
WebSocketClientProtocolHandler.ClientHandshakeStateEvent wsevt = 
(WebSocketClientProtocolHandler.ClientHandshakeStateEvent) evt;
+                               switch (wsevt) { // no default branch: any other handshake state is ignored
+                                       case HANDSHAKE_ISSUED:
+                                               LOG.debug("WebSocket handshake 
initiated");
+                                               break;
+                                       case HANDSHAKE_COMPLETE:
+                                               LOG.debug("WebSocket handshake 
completed");
+                                               webSocketFuture.complete(this); // this handler doubles as the WebSocket handle
+                                               break;
+                               }
+                       }
+                       else {
+                               super.userEventTriggered(ctx, evt);
+                       }
+               }
+               /**
+                * Delivers an inbound message to every registered listener, in registration order.
+                */
+               @Override
+               protected void channelRead0(ChannelHandlerContext 
channelHandlerContext, I msg) throws Exception {
+                       for (WebSocketListener<I> listener : listeners) {
+                               listener.onEvent(msg);
+                       }
+               }
+               /**
+                * Registers an additional listener; safe to call while messages are being delivered.
+                */
+               @Override
+               public void addListener(WebSocketListener<I> listener) {
+                       listeners.add(listener);
+               }
+               /**
+                * Writes and flushes the message to the channel.
+                *
+                * <p>NOTE(review): writeAndFlush takes ownership of reference-counted messages, so the
+                * release in the finally block would double-release such a message. For bodies that are
+                * not ReferenceCounted, the release is a no-op.
+                */
+               @Override
+               public ChannelFuture send(O message) {
+                       try {
+                               return channel.writeAndFlush(message);
+                       }
+                       finally {
+                               ReferenceCountUtil.release(message);
+                       }
+               }
+               /**
+                * Closes the underlying channel.
+                */
+               @Override
+               public ChannelFuture close() {
+                       return channel.close();
+               }
+       }
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/AbstractWebSocketHandler.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/AbstractWebSocketHandler.java
new file mode 100644
index 00000000000..3a21d8a77d9
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/AbstractWebSocketHandler.java
@@ -0,0 +1,304 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.rest.handler;
+
+import org.apache.flink.api.common.time.Time;
+import org.apache.flink.runtime.concurrent.FutureUtils;
+import org.apache.flink.runtime.rest.handler.util.HandlerUtils;
+import org.apache.flink.runtime.rest.messages.ErrorResponseBody;
+import org.apache.flink.runtime.rest.messages.MessageParameters;
+import org.apache.flink.runtime.rest.messages.RequestBody;
+import org.apache.flink.runtime.rest.messages.ResponseBody;
+import org.apache.flink.runtime.rest.messages.WebSocketSpecification;
+import org.apache.flink.runtime.webmonitor.RestfulGateway;
+import org.apache.flink.runtime.webmonitor.retriever.GatewayRetriever;
+import org.apache.flink.util.ExceptionUtils;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandler;
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext;
+import 
org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandlerAdapter;
+import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpRequest;
+import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponseStatus;
+import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.router.Routed;
+import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
+import org.apache.flink.shaded.netty4.io.netty.util.AttributeKey;
+import org.apache.flink.shaded.netty4.io.netty.util.ReferenceCountUtil;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnull;
+
+import java.util.concurrent.CompletableFuture;
+
+/**
+ * A channel handler for WebSocket resources.
+ *
+ * <p>This handler handles handshaking and ongoing messaging with a WebSocket client,
+ * based on a {@link WebSocketSpecification} that describes the REST resource location,
+ * parameter type, and message inbound/outbound types. Messages are automatically
+ * encoded from (and decoded to) JSON text.
+ *
+ * <p>Lifecycle of a channel:
+ * <ol>
+ *     <li>When this handler is added to a pipeline, it inserts a {@link RedirectHandler} directly in
+ *     front of itself, unless a handler named {@code RedirectHandler.class.getName()} is already
+ *     present. The gateway carried by a {@code RedirectHandler.GatewayRetrieved} user event is stored
+ *     as a handler-context attribute and exposed through {@link #getGateway}.</li>
+ *     <li>An inbound {@link Routed} HTTP request is treated as a handshake request. Its path and query
+ *     parameters are resolved into the message parameters of the specification; if that fails, a
+ *     {@code 400 Bad Request} response is sent.</li>
+ *     <li>{@link #handshakeInitiated} is consulted. If it fails, an error response is sent. Otherwise a
+ *     {@link WebSocketServerProtocolHandler} and the JSON message codec are inserted in front of this
+ *     handler, and the HTTP request is re-fired into the new protocol handler, which performs the
+ *     handshake.</li>
+ *     <li>On {@code HANDSHAKE_COMPLETE}, {@link #handshakeCompleted} is invoked.</li>
+ *     <li>Every later inbound message of the client class of the specification is passed to
+ *     {@link #messageReceived} and released afterwards.</li>
+ * </ol>
+ *
+ * <p>Subclasses should override the following methods to extend the respective phases.
+ * <ol>
+ *     <li>{@code handshakeInitiated} - occurs upon receipt of a handshake request from an HTTP client.
+ *     Useful for parameter validation.</li>
+ *     <li>{@code handshakeCompleted} - occurs upon successful completion; WebSocket is ready for I/O.</li>
+ *     <li>{@code messageReceived} - occurs when a WebSocket message is received on the channel.</li>
+ * </ol>
+ *
+ * <p>The handler supports gateway availability announcements.
+ *
+ * <p>NOTE(review): the {@link RedirectHandler} and the message codec are single instances created in
+ * the constructor and inserted into every pipeline this handler joins. Unless those handlers are
+ * sharable, one instance of this class should serve a single channel only -- confirm.
+ *
+ * @param <T> The gateway type.
+ * @param <M> The REST parameter type.
+ * @param <I> The inbound message type.
+ * @param <O> The outbound message type.
+ */
+public abstract class AbstractWebSocketHandler<T extends RestfulGateway, M 
extends MessageParameters, I extends RequestBody, O extends ResponseBody> 
extends ChannelInboundHandlerAdapter {
+
+       protected final Logger log = LoggerFactory.getLogger(getClass());
+
+       private final RedirectHandler redirectHandler; // raw type; inserted in front of this handler by handlerAdded
+
+       private final AttributeKey<T> gatewayAttr; // handler-context attribute holding the retrieved gateway
+
+       private final WebSocketSpecification<M, I, O> specification;
+
+       private final ChannelHandler messageCodec; // JSON codec inserted by upgradeToWebSocket
+
+       private final AttributeKey<M> parametersAttr; // handler-context attribute holding the resolved REST parameters
+
+       /**
+        * Creates a new handler.
+        *
+        * @param localAddressFuture future of the local address, passed to the embedded {@link RedirectHandler}
+        * @param leaderRetriever retriever for the leader gateway, passed to the embedded {@link RedirectHandler}
+        * @param timeout timeout passed to the embedded {@link RedirectHandler}
+        * @param specification WebSocket resource specification (URL, parameters, message types, subprotocol)
+        */
+       public AbstractWebSocketHandler(
+               @Nonnull CompletableFuture<String> localAddressFuture,
+               @Nonnull GatewayRetriever<? extends T> leaderRetriever,
+               @Nonnull Time timeout,
+               @Nonnull WebSocketSpecification<M, I, O> specification) {
+               this.redirectHandler = new 
RedirectHandler<>(localAddressFuture, leaderRetriever, timeout);
+               this.gatewayAttr = AttributeKey.valueOf("gateway");
+               this.specification = specification;
+               this.messageCodec = new 
JsonWebSocketMessageCodec<>(specification.getClientClass(), 
specification.getServerClass()); // presumably (inbound, outbound): clients send the client class, the server replies with the server class
+               this.parametersAttr = AttributeKey.valueOf("parameters");
+       }
+
+       /**
+        * Sets the gateway associated with the channel.
+        *
+        * @param ctx the context of this handler
+        * @param gateway the gateway announced by the redirect handler
+        */
+       private void setGateway(ChannelHandlerContext ctx, T gateway) {
+               ctx.attr(gatewayAttr).set(gateway);
+       }
+
+       /**
+        * Returns the gateway associated with the channel.
+        *
+        * @param ctx the context of this handler
+        * @return the gateway recorded from the most recent {@code RedirectHandler.GatewayRetrieved} event
+        * @throws IllegalStateException if no gateway has been recorded yet
+        */
+       public T getGateway(ChannelHandlerContext ctx) {
+               T t = ctx.attr(gatewayAttr).get();
+               Preconditions.checkState(t != null, "Gateway is not 
available.");
+               return t;
+       }
+
+       /**
+        * Sets the resource parameters associated with the channel.
+        *
+        * <p>The parameters are established by the WebSocket handshake request.
+        *
+        * @param ctx the context of this handler
+        * @param parameters the resolved REST parameters of the handshake request
+        */
+       private void setMessageParameters(ChannelHandlerContext ctx, M 
parameters) {
+               ctx.attr(parametersAttr).set(parameters);
+       }
+
+       /**
+        * Returns the resource parameters associated with the channel.
+        *
+        * @param ctx the context of this handler
+        * @return the REST parameters resolved from the handshake request
+        * @throws IllegalStateException if no handshake request has been parsed yet
+        */
+       public M getMessageParameters(ChannelHandlerContext ctx) {
+               M o = ctx.attr(parametersAttr).get();
+               Preconditions.checkState(o != null, "Message parameters are not 
available.");
+               return o;
+       }
+       /**
+        * Inserts the {@link RedirectHandler} directly in front of this handler, unless the pipeline
+        * already contains a handler registered under {@code RedirectHandler.class.getName()}.
+        */
+       @Override
+       public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
+               if (ctx.pipeline().get(RedirectHandler.class.getName()) == 
null) { // insert the redirect handler at most once per pipeline
+                       ctx.pipeline().addBefore(ctx.name(), 
RedirectHandler.class.getName(), redirectHandler); // directly in front of this handler
+               }
+       }
+       /**
+        * Records the gateway carried by a {@code RedirectHandler.GatewayRetrieved} event and invokes
+        * {@link #handshakeCompleted} when the server handshake completes. Every event, handled here or
+        * not, is propagated further down the pipeline.
+        */
+       @Override
+       @SuppressWarnings("unchecked")
+       public void userEventTriggered(ChannelHandlerContext ctx, Object evt) 
throws Exception {
+               if (evt instanceof RedirectHandler.GatewayRetrieved) {
+                       T gateway = ((RedirectHandler.GatewayRetrieved<T>) 
evt).getGateway();
+                       setGateway(ctx, gateway);
+                       log.debug("Gateway retrieved: {}", gateway);
+               }
+               else if (evt instanceof 
WebSocketServerProtocolHandler.ServerHandshakeStateEvent) {
+                       
WebSocketServerProtocolHandler.ServerHandshakeStateEvent handshakeEvent = 
(WebSocketServerProtocolHandler.ServerHandshakeStateEvent) evt;
+                       if (handshakeEvent == 
WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE) {
+                               log.debug("Handshake completed with client IP: 
{}", ctx.channel().remoteAddress());
+                               M parameters = getMessageParameters(ctx);
+                               handshakeCompleted(ctx, parameters);
+                       }
+               }
+
+               super.userEventTriggered(ctx, evt);
+       }
+       /**
+        * Dispatches an inbound object.
+        *
+        * <ul>
+        *     <li>Messages of the client class of the specification are passed to
+        *     {@link #messageReceived} and released afterwards.</li>
+        *     <li>Objects that are not {@link Routed} requests are forwarded unchanged.</li>
+        *     <li>A {@link Routed} request is treated as a handshake request. Its parameters are resolved
+        *     (a 400 response is sent on failure) and {@link #handshakeInitiated} is consulted. On success
+        *     the channel is upgraded; on failure an error response is sent, carrying the status of a
+        *     {@link RestHandlerException} or 500 for any other error. The routed request is released
+        *     in all cases.</li>
+        * </ul>
+        */
+       @Override
+       @SuppressWarnings("unchecked")
+       public void channelRead(ChannelHandlerContext ctx, Object o) throws 
Exception {
+               if 
(specification.getClientClass().isAssignableFrom(o.getClass())) { // a decoded WebSocket message, only expected after the upgrade
+                       // process an inbound message
+                       M parameters = getMessageParameters(ctx);
+                       try {
+                               messageReceived(ctx, parameters, (I) o);
+                       }
+                       finally {
+                               ReferenceCountUtil.release(o);
+                       }
+                       return;
+               }
+
+               if (!(o instanceof Routed)) {
+                       // a foreign message
+                       ctx.fireChannelRead(o);
+                       return;
+               }
+
+               // process an inbound HTTP request
+               Routed request = (Routed) o;
+
+               // parse the REST request parameters
+               M messageParameters = 
specification.getUnresolvedMessageParameters();
+               try {
+                       
messageParameters.resolveParameters(request.pathParams(), 
request.queryParams());
+                       if (!messageParameters.isResolved()) {
+                               throw new IllegalArgumentException("One or more 
mandatory parameters is missing.");
+                       }
+               }
+               catch (IllegalArgumentException e) {
+                       HandlerUtils.sendErrorResponse(
+                               ctx,
+                               request.request(),
+                               new ErrorResponseBody(String.format("Bad 
request, could not parse parameters: %s", e.getMessage())),
+                               HttpResponseStatus.BAD_REQUEST);
+                       ReferenceCountUtil.release(request); // the request is not forwarded after an error response
+                       return;
+               }
+               setMessageParameters(ctx, messageParameters); // read back on handshake completion and for every message
+
+               // validate the inbound handshake request with the subclass
+               CompletableFuture<Void> handshakeReady;
+               try {
+                       handshakeReady = handshakeInitiated(ctx, 
messageParameters);
+               } catch (Exception e) {
+                       handshakeReady = FutureUtils.completedExceptionally(e);
+               }
+               handshakeReady.whenCompleteAsync((Void v, Throwable throwable) 
-> { // runs on the executor of this handler context
+                       try {
+                               if (throwable != null) {
+                                       Throwable error = 
ExceptionUtils.stripCompletionException(throwable);
+                                       if (error instanceof 
RestHandlerException) {
+                                               final RestHandlerException rhe 
= (RestHandlerException) error;
+                                               log.error("Exception occurred 
in REST handler.", error);
+                                               
HandlerUtils.sendErrorResponse(ctx, request.request(), new 
ErrorResponseBody(rhe.getMessage()), rhe.getHttpResponseStatus());
+                                       } else {
+                                               log.error("Implementation 
error: Unhandled exception.", error);
+                                               
HandlerUtils.sendErrorResponse(ctx, request.request(), new 
ErrorResponseBody("Internal server error."), 
HttpResponseStatus.INTERNAL_SERVER_ERROR);
+                                       }
+                               } else {
+                                       upgradeToWebSocket(ctx, request);
+                               }
+                       }
+                       finally {
+                               ReferenceCountUtil.release(request); // balanced by the retain of the inner HTTP request in upgradeToWebSocket
+                       }
+               }, ctx.executor());
+       }
+       /**
+        * Inserts a {@link WebSocketServerProtocolHandler} (for the routed request path and the
+        * subprotocol of the specification), followed by the JSON message codec, in front of this
+        * handler. It then re-fires the retained HTTP request from the context preceding this handler,
+        * so that the new protocol handler performs the handshake.
+        */
+       private void upgradeToWebSocket(final ChannelHandlerContext ctx, Routed 
msg) {
+
+               // store the context of the handler that precedes the current 
handler,
+               // to use that context later to forward the HTTP request to the 
WebSocket protocol handler
+               String before = 
ctx.pipeline().names().get(ctx.pipeline().names().indexOf(ctx.name()) - 1); // NOTE(review): relies on a preceding handler (the injected RedirectHandler); without one the index would be -1 and get() would throw
+               ChannelHandlerContext beforeCtx = 
ctx.pipeline().context(before);
+
+               // inject the websocket protocol handler into this channel, to 
be active
+               // until the channel is closed.  note that the handshake may or 
may not complete synchronously.
+               ctx.pipeline().addBefore(ctx.name(), 
WebSocketServerProtocolHandler.class.getName(),
+                       new WebSocketServerProtocolHandler(msg.path(), 
specification.getSubprotocol())); // handshake path is the routed request path
+
+               // inject the message codec
+               ctx.pipeline().addBefore(ctx.name(), 
messageCodec.getClass().getName(), messageCodec);
+
+               log.debug("Upgraded channel with WS protocol handler and 
message codec.");
+
+               // forward the message to the installed protocol handler to 
initiate handshaking
+               HttpRequest request = msg.request();
+               ReferenceCountUtil.retain(request); // keeps the request alive past the release in the channelRead callback
+               beforeCtx.fireChannelRead(request); // delivered to the handler after beforeCtx, i.e. the newly inserted protocol handler
+       }
+       /**
+        * Logs the error and closes the channel.
+        */
+       @Override
+       public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) 
throws Exception {
+               log.error("WebSocket channel error; closing the channel.", 
cause);
+               ctx.close();
+       }
+
+       /**
+        * Handles a client handshake request to open a WebSocket resource. Returns a
+        * {@link CompletableFuture} to complete handshaking.
+        *
+        * <p>Implementations may decide whether to throw {@link RestHandlerException}s or fail the
+        * returned {@link CompletableFuture} with a {@link RestHandlerException}.
+        *
+        * <p>Failing the future with another exception type or throwing other exceptions is regarded as
+        * an implementation error, as it does not allow us to provide a meaningful HTTP status code. In
+        * this case a {@link HttpResponseStatus#INTERNAL_SERVER_ERROR} will be returned.
+        *
+        * @param ctx the channel handler context
+        * @param parameters the REST parameters
+        * @return future indicating completion of handshake pre-processing.
+        * @throws RestHandlerException to produce a pre-formatted HTTP error response.
+        */
+       protected CompletableFuture<Void> 
handshakeInitiated(ChannelHandlerContext ctx, M parameters) throws Exception {
+               return CompletableFuture.completedFuture(null);
+       }
+
+       /**
+        * Invoked when the current channel has completed the handshaking to establish a WebSocket
+        * connection.
+        *
+        * @param ctx the channel handler context
+        * @param parameters the REST parameters
+        * @throws Exception if processing failed.
+        */
+       protected void handshakeCompleted(ChannelHandlerContext ctx, M 
parameters) throws Exception {
+       }
+
+       /**
+        * Invoked when the current channel has received a WebSocket message.
+        *
+        * <p>The message object is automatically released after this method is called.
+        *
+        * @param ctx the channel handler context
+        * @param parameters the REST parameters
+        * @param msg the message received
+        * @throws Exception if the message could not be processed.
+        */
+       protected abstract void messageReceived(ChannelHandlerContext ctx, M 
parameters, I msg) throws Exception;
+}
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/HandlerRequest.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/HandlerRequest.java
index aacf0a22a16..f92282956b2 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/HandlerRequest.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/HandlerRequest.java
@@ -25,10 +25,10 @@
 import org.apache.flink.util.Preconditions;
 
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.StringJoiner;
+import java.util.function.Function;
+import java.util.stream.Collectors;
 
 /**
  * Simple container for the request to a handler, that contains the {@link 
RequestBody} and path/query parameters.
@@ -39,52 +39,29 @@
 public class HandlerRequest<R extends RequestBody, M extends 
MessageParameters> {
 
        private final R requestBody;
-       private final Map<Class<? extends MessagePathParameter<?>>, 
MessagePathParameter<?>> pathParameters = new HashMap<>(2);
-       private final Map<Class<? extends MessageQueryParameter<?>>, 
MessageQueryParameter<?>> queryParameters = new HashMap<>(2);
+       private final Map<Class<? extends MessagePathParameter<?>>, 
MessagePathParameter<?>> pathParameters;
+       private final Map<Class<? extends MessageQueryParameter<?>>, 
MessageQueryParameter<?>> queryParameters;
 
        public HandlerRequest(R requestBody, M messageParameters) throws 
HandlerRequestException {
                this(requestBody, messageParameters, Collections.emptyMap(), 
Collections.emptyMap());
        }
 
+       @SuppressWarnings("unchecked")
        public HandlerRequest(R requestBody, M messageParameters, Map<String, 
String> receivedPathParameters, Map<String, List<String>> 
receivedQueryParameters) throws HandlerRequestException {
                this.requestBody = Preconditions.checkNotNull(requestBody);
                Preconditions.checkNotNull(messageParameters);
                Preconditions.checkNotNull(receivedQueryParameters);
                Preconditions.checkNotNull(receivedPathParameters);
 
-               for (MessagePathParameter<?> pathParameter : 
messageParameters.getPathParameters()) {
-                       String value = 
receivedPathParameters.get(pathParameter.getKey());
-                       if (value != null) {
-                               try {
-                                       pathParameter.resolveFromString(value);
-                               } catch (Exception e) {
-                                       throw new 
HandlerRequestException("Cannot resolve path parameter (" + 
pathParameter.getKey() + ") from value \"" + value + "\".");
-                               }
-
-                               @SuppressWarnings("unchecked")
-                               Class<? extends MessagePathParameter<?>> clazz 
= (Class<? extends MessagePathParameter<?>>) pathParameter.getClass();
-                               pathParameters.put(clazz, pathParameter);
-                       }
+               try {
+                       
messageParameters.resolveParameters(receivedPathParameters, 
receivedQueryParameters);
                }
-
-               for (MessageQueryParameter<?> queryParameter : 
messageParameters.getQueryParameters()) {
-                       List<String> values = 
receivedQueryParameters.get(queryParameter.getKey());
-                       if (values != null && !values.isEmpty()) {
-                               StringJoiner joiner = new StringJoiner(",");
-                               values.forEach(joiner::add);
-
-                               try {
-                                       
queryParameter.resolveFromString(joiner.toString());
-                               } catch (Exception e) {
-                                       throw new 
HandlerRequestException("Cannot resolve query parameter (" + 
queryParameter.getKey() + ") from value \"" + joiner + "\".");
-                               }
-
-                               @SuppressWarnings("unchecked")
-                               Class<? extends MessageQueryParameter<?>> clazz 
= (Class<? extends MessageQueryParameter<?>>) queryParameter.getClass();
-                               queryParameters.put(clazz, queryParameter);
-                       }
-
+               catch (IllegalArgumentException e) {
+                       throw new HandlerRequestException("Unable to resolve 
the request parameters: " + e.getMessage());
                }
+
+               pathParameters = 
messageParameters.getPathParameters().stream().collect(Collectors.toMap(p -> 
(Class<? extends MessagePathParameter<?>>) p.getClass(), Function.identity()));
+               queryParameters = 
messageParameters.getQueryParameters().stream().collect(Collectors.toMap(p -> 
(Class<? extends MessageQueryParameter<?>>) p.getClass(), Function.identity()));
        }
 
        /**
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/JsonWebSocketMessageCodec.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/JsonWebSocketMessageCodec.java
new file mode 100644
index 00000000000..48127bd88fe
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/JsonWebSocketMessageCodec.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.rest.handler;
+
+import org.apache.flink.runtime.rest.util.RestMapperUtils;
+
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandler;
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext;
+import 
org.apache.flink.shaded.netty4.io.netty.channel.CombinedChannelDuplexHandler;
+import org.apache.flink.shaded.netty4.io.netty.handler.codec.DecoderException;
+import org.apache.flink.shaded.netty4.io.netty.handler.codec.EncoderException;
+import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.MessageToMessageDecoder;
+import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.MessageToMessageEncoder;
+import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.ObjectReader;
+import com.fasterxml.jackson.databind.ObjectWriter;
+
+import java.io.StringReader;
+import java.io.StringWriter;
+import java.util.List;
+
+/**
+ * A codec for JSON-encoded WebSocket messages.
+ *
+ * <p>Inbound {@link TextWebSocketFrame}s are decoded into objects of the inbound message type, and
+ * outbound objects of the outbound message type are encoded as {@link TextWebSocketFrame}s. Both
+ * directions use the strict object mapper obtained from {@link RestMapperUtils}.
+ *
+ * @param <I> the inbound message type (converted from JSON).
+ * @param <O> the outbound message type (converted to JSON).
+ */
+@ChannelHandler.Sharable
+public class JsonWebSocketMessageCodec<I, O> extends CombinedChannelDuplexHandler<JsonWebSocketMessageCodec.WebSocketMessageDecoder<I>, JsonWebSocketMessageCodec.WebSocketMessageEncoder<O>> {
+
+       private static final ObjectMapper mapper = RestMapperUtils.getStrictObjectMapper();
+
+       /**
+        * Creates a codec for the given message classes.
+        *
+        * @param inboundMessageClass class that inbound JSON messages are read as
+        * @param outboundMessageClass class of the outbound messages that are written as JSON
+        */
+       public JsonWebSocketMessageCodec(Class<I> inboundMessageClass, Class<O> outboundMessageClass) {
+               // NOTE(review): in Netty 4.1, CombinedChannelDuplexHandler keeps per-channel state and its
+               // constructor rejects @Sharable subclasses (ensureNotSharable); confirm the class-level
+               // @Sharable annotation against the shaded Netty version in use.
+               super(new WebSocketMessageDecoder<>(inboundMessageClass), new WebSocketMessageEncoder<>(outboundMessageClass));
+       }
+
+       /**
+        * Decodes a JSON {@link TextWebSocketFrame} into an object of the inbound message type.
+        *
+        * @param <I> the inbound message type
+        */
+       @ChannelHandler.Sharable
+       static class WebSocketMessageDecoder<I> extends MessageToMessageDecoder<TextWebSocketFrame> {
+
+               private final ObjectReader reader;
+
+               public WebSocketMessageDecoder(Class<I> messageClass) {
+                       reader = mapper.readerFor(messageClass);
+               }
+
+               @Override
+               protected void decode(ChannelHandlerContext channelHandlerContext, TextWebSocketFrame frame, List<Object> list) throws Exception {
+                       // a single try-with-resources also routes failures from reading or closing into the catch
+                       try (StringReader sr = new StringReader(frame.text())) {
+                               I i = reader.readValue(sr);
+                               list.add(i);
+                       } catch (Exception e) {
+                               throw new DecoderException("Unable to decode the WebSocket frame as a JSON object", e);
+                       }
+               }
+       }
+
+       /**
+        * Encodes an object of the outbound message type as a JSON {@link TextWebSocketFrame}.
+        *
+        * @param <O> the outbound message type
+        */
+       @ChannelHandler.Sharable
+       static class WebSocketMessageEncoder<O> extends MessageToMessageEncoder<O> {
+
+               private final ObjectWriter writer;
+
+               protected WebSocketMessageEncoder(Class<O> messageClass) {
+                       // the type matcher built from messageClass decides which outbound messages are encoded
+                       super(messageClass);
+                       writer = mapper.writerFor(messageClass);
+               }
+
+               @Override
+               protected void encode(ChannelHandlerContext ctx, O o, List<Object> list) throws Exception {
+                       try {
+                               StringWriter sw = new StringWriter();
+                               writer.writeValue(sw, o);
+                               list.add(new TextWebSocketFrame(sw.toString()));
+                       }
+                       catch (Exception e) {
+                               throw new EncoderException("Unable to encode the JSON object as a WebSocket frame", e);
+                       }
+               }
+       }
+}
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/RedirectHandler.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/RedirectHandler.java
index 40b67767737..07a79e11902 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/RedirectHandler.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/RedirectHandler.java
@@ -49,10 +49,16 @@
  * {@link SimpleChannelInboundHandler} which encapsulates the redirection 
logic for the
  * REST endpoints.
  *
+ * <p>This handler supports two modes of use.
+ * <ol>
+ *     <li>Inheritance - subclasses override {@code respondAsLeader}.</li>
+ *     <li>Composition - the subsequent inbound handlers in the pipeline receive a user event containing the gateway, see {@link GatewayRetrieved}.</li>
+ * </ol>
+ *
  * @param <T> type of the leader to retrieve
  */
 @ChannelHandler.Sharable
-public abstract class RedirectHandler<T extends RestfulGateway> extends 
SimpleChannelInboundHandler<Routed> {
+public class RedirectHandler<T extends RestfulGateway> extends 
SimpleChannelInboundHandler<Routed> {
 
        protected final Logger logger = LoggerFactory.getLogger(getClass());
 
@@ -64,7 +70,7 @@
 
        private String localAddress;
 
-       protected RedirectHandler(
+       public RedirectHandler(
                        @Nonnull CompletableFuture<String> localAddressFuture,
                        @Nonnull GatewayRetriever<? extends T> leaderRetriever,
                        @Nonnull Time timeout) {
@@ -172,5 +178,39 @@ protected void channelRead0(
                }
        }
 
-       protected abstract void respondAsLeader(ChannelHandlerContext 
channelHandlerContext, Routed routed, T gateway) throws Exception;
+       /**
+        * Responds to an HTTP request in combination with the leader gateway.
+        *
+        * <p>The default behavior is to announce the leader gateway with a 
user event,
+        * and then to forward the HTTP request to the next handler.
+        *
+        * @param ctx the channel handler context
+        * @param routed the HTTP request
+        * @param gateway the leader gateway
+        */
+       protected void respondAsLeader(ChannelHandlerContext ctx, Routed 
routed, T gateway) throws Exception {
+
+               // announce the gateway to upstream handlers
+               ctx.fireUserEventTriggered(new GatewayRetrieved<>(gateway));
+
+               // propagate the HTTP request
+               ReferenceCountUtil.retain(routed);
+               ctx.fireChannelRead(routed);
+       }
+
+       /**
+        * User event announcing that the leader gateway has been retrieved.
+        *
+        * @param <T> the gateway type
+        */
+       public static class GatewayRetrieved<T> {
+
+               /** The retrieved gateway. */
+               private final T gateway;
+
+               /**
+                * Creates the event for the given gateway.
+                *
+                * @param gateway the retrieved gateway
+                */
+               public GatewayRetrieved(T gateway) {
+                       this.gateway = gateway;
+               }
+
+               /**
+                * Returns the gateway carried by this event.
+                *
+                * @return the retrieved gateway
+                */
+               public T getGateway() {
+                       return this.gateway;
+               }
+       }
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/MessageParameters.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/MessageParameters.java
index b19b12e7206..b91c4510e83 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/MessageParameters.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/MessageParameters.java
@@ -21,9 +21,15 @@
 import org.apache.flink.util.Preconditions;
 
 import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.StringJoiner;
 
 /**
  * This class defines the path/query {@link MessageParameter}s that can be 
used for a request.
+ *
+ * <p>An instance of {@link MessageParameters} is mutable.
  */
 public abstract class MessageParameters {
 
@@ -61,7 +67,7 @@ public final boolean isResolved() {
         * <p>Unresolved optional parameters will be ignored.
         *
         * @param genericUrl URL to resolve
-        * @param parameters message parameters parameters
+        * @param parameters message parameters
         * @return resolved url, e.g "/jobs/1234?state=running"
         * @throws IllegalStateException if any mandatory parameter was not 
resolved
         */
@@ -100,4 +106,74 @@ public static String resolveUrl(String genericUrl, 
MessageParameters parameters)
 
                return path.toString();
        }
+
+       /**
+        * Resolves the message parameters using the given path and query parameter values.
+        *
+        * <p>This method updates the state of the parameters of this message. A parameter is only
+        * resolved if a value was received for its key; multiple values received for one query
+        * parameter are joined with commas before being resolved.
+        *
+        * @param receivedPathParameters the received path parameters, keyed by parameter key.
+        * @param receivedQueryParameters the received query parameters, keyed by parameter key.
+        * @throws IllegalArgumentException if a parameter value cannot be processed; the underlying
+        *                                  failure is attached as the cause.
+        */
+       public void resolveParameters(Map<String, String> receivedPathParameters, Map<String, List<String>> receivedQueryParameters)  {
+               for (MessagePathParameter<?> pathParameter : getPathParameters()) {
+                       String value = receivedPathParameters.get(pathParameter.getKey());
+                       if (value != null) {
+                               try {
+                                       pathParameter.resolveFromString(value);
+                               } catch (Exception e) {
+                                       // keep the cause so the actual parse failure is not lost
+                                       throw new IllegalArgumentException("Cannot resolve path parameter (" + pathParameter.getKey() + ") from value \"" + value + "\".", e);
+                               }
+                       }
+               }
+
+               for (MessageQueryParameter<?> queryParameter : getQueryParameters()) {
+                       List<String> values = receivedQueryParameters.get(queryParameter.getKey());
+                       if (values != null && !values.isEmpty()) {
+                               StringJoiner joiner = new StringJoiner(",");
+                               values.forEach(joiner::add);
+                               String joinedValues = joiner.toString();
+
+                               try {
+                                       queryParameter.resolveFromString(joinedValues);
+                               } catch (Exception e) {
+                                       // keep the cause so the actual parse failure is not lost
+                                       throw new IllegalArgumentException("Cannot resolve query parameter (" + queryParameter.getKey() + ") from value \"" + joinedValues + "\".", e);
+                               }
+                       }
+               }
+       }
+
+
+       /**
+        * Returns the value of the {@link MessagePathParameter} for the given class.
+        *
+        * @param parameterClass class of the parameter
+        * @param <X>            the value type that the parameter contains
+        * @param <PP>           type of the path parameter
+        * @return path parameter value for the given class
+        * @throws IllegalStateException if no parameter of the given class exists, or its value is null
+        */
+       @SuppressWarnings("unchecked")
+       public <X, PP extends MessagePathParameter<X>> X getPathParameter(Class<PP> parameterClass) {
+               // Find the parameter first and only then map it to its value: a null value (presumably an
+               // unresolved parameter) would make Stream#findAny throw a NullPointerException, whereas
+               // Optional#map turns it into an empty Optional and thus into the documented exception.
+               return getPathParameters().stream()
+                       .filter(p -> parameterClass.equals(p.getClass()))
+                       .findAny()
+                       .map(p -> ((PP) p).getValue())
+                       .orElseThrow(() -> new IllegalStateException("No parameter could be found for the given class."));
+       }
+
+       /**
+        * Returns the value of the {@link MessageQueryParameter} for the given class.
+        *
+        * @param parameterClass class of the parameter
+        * @param <X>            the value type that the parameter contains
+        * @param <QP>           type of the query parameter
+        * @return query parameter value for the given class, or an empty list if no parameter value exists for the given class
+        */
+       @SuppressWarnings("unchecked")
+       public <X, QP extends MessageQueryParameter<X>> List<X> getQueryParameter(Class<QP> parameterClass) {
+               // Find the parameter first and only then map it to its value: a null value (presumably an
+               // unresolved parameter) would make Stream#findAny throw a NullPointerException instead of
+               // yielding the documented empty list.
+               return getQueryParameters().stream()
+                       .filter(p -> parameterClass.equals(p.getClass()))
+                       .findAny()
+                       .map(p -> ((QP) p).getValue())
+                       .orElse(Collections.emptyList());
+       }
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/RequestBody.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/RequestBody.java
index ca55b17532b..7e2ef568448 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/RequestBody.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/RequestBody.java
@@ -19,7 +19,7 @@
 package org.apache.flink.runtime.rest.messages;
 
 /**
- * Marker interface for all requests of the REST API. This class represents 
the http body of a request.
+ * Marker interface for all requests of the REST API. This class represents 
the body of an HTTP request or WebSocket message.
  *
  * <p>Subclass instances are converted to JSON using jackson-databind. 
Subclasses must have a constructor that accepts
  * all fields of the JSON request, that should be annotated with {@code 
@JsonCreator}.
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/ResponseBody.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/ResponseBody.java
index d4e94d1d6ab..6d183903637 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/ResponseBody.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/ResponseBody.java
@@ -19,7 +19,7 @@
 package org.apache.flink.runtime.rest.messages;
 
 /**
- * Marker interface for all responses of the REST API. This class represents 
the http body of a response.
+ * Marker interface for all responses of the REST API. This class represents 
the body of an HTTP response or WebSocket message.
  *
  * <p>Subclass instances are converted to JSON using jackson-databind. 
Subclasses must have a constructor that accepts
  * all fields of the JSON response, that should be annotated with {@code 
@JsonCreator}.
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/WebSocketSpecification.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/WebSocketSpecification.java
new file mode 100644
index 00000000000..38ca0e86743
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/WebSocketSpecification.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.rest.messages;
+
+import org.apache.flink.runtime.rest.HttpMethodWrapper;
+import org.apache.flink.runtime.rest.handler.RestHandlerSpecification;
+
+/**
+ * A {@link RestHandlerSpecification} extended with the information needed to serve a WebSocket.
+ *
+ * <p>Implementations must be stateless.
+ *
+ * @param <M> message parameters type
+ * @param <I> inbound (client-to-server) message type
+ * @param <O> outbound (server-to-client) message type
+ */
+public interface WebSocketSpecification<M extends MessageParameters, I extends RequestBody, O extends ResponseBody> extends RestHandlerSpecification {
+
+       /**
+        * Returns {@link HttpMethodWrapper#GET}, the method used by the WebSocket opening handshake.
+        *
+        * @return the HTTP GET method
+        */
+       @Override
+       default HttpMethodWrapper getHttpMethod() {
+               return HttpMethodWrapper.GET;
+       }
+
+       /**
+        * Returns the WebSocket subprotocol associated with the REST resource.
+        *
+        * @return the subprotocol name
+        */
+       String getSubprotocol();
+
+       /**
+        * Returns the base class of the messages sent from client to server.
+        *
+        * @return class of the client-to-server messages
+        */
+       Class<I> getClientClass();
+
+       /**
+        * Returns the base class of the messages sent from server to client.
+        *
+        * @return class of the server-to-client messages
+        */
+       Class<O> getServerClass();
+
+       /**
+        * Creates a fresh {@link MessageParameters} object whose parameters are not yet resolved.
+        *
+        * @return new message parameters object
+        */
+       M getUnresolvedMessageParameters();
+}
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/websocket/KeyedChannelRouter.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/websocket/KeyedChannelRouter.java
new file mode 100644
index 00000000000..4fecafb4100
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/websocket/KeyedChannelRouter.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.rest.websocket;
+
+import org.apache.flink.shaded.netty4.io.netty.channel.Channel;
+import org.apache.flink.shaded.netty4.io.netty.channel.group.ChannelGroup;
+import 
org.apache.flink.shaded.netty4.io.netty.channel.group.ChannelGroupFuture;
+import org.apache.flink.shaded.netty4.io.netty.util.AttributeKey;
+
+import javax.annotation.Nonnull;
+
+import java.util.Objects;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/**
+ * Routes messages to channels based on a routing key.
+ *
+ * <p>A registered channel stores its routing key in the channel attribute given at construction;
+ * write and close operations apply to those channels of the group whose attribute equals the
+ * requested key.
+ *
+ * @param <K> the key type.
+ */
+public class KeyedChannelRouter<K> {
+
+       /** Channel attribute holding the routing key of a registered channel. */
+       private final AttributeKey<K> attributeKey;
+
+       /** Group containing the registered channels. */
+       private final ChannelGroup channels;
+
+       /**
+        * Creates a router on top of the given channel group.
+        *
+        * @param attributeKey the channel attribute used to store each channel's routing key
+        * @param channelGroup the group to which registered channels are added
+        */
+       public KeyedChannelRouter(AttributeKey<K> attributeKey, ChannelGroup channelGroup) {
+               this.attributeKey = checkNotNull(attributeKey);
+               this.channels = checkNotNull(channelGroup);
+       }
+
+       /**
+        * Registers a channel to receive messages for a given routing key.
+        *
+        * @param channel the channel to register.
+        * @param routingKey the key for which the channel receives messages.
+        */
+       public void register(@Nonnull Channel channel, @Nonnull K routingKey) {
+               channel.attr(attributeKey).set(routingKey);
+               channels.add(channel);
+       }
+
+       /**
+        * Unregisters a channel.
+        *
+        * <p>The channel is removed from the group; its routing key attribute is left unchanged.
+        *
+        * @param channel the channel to unregister.
+        */
+       public void unregister(@Nonnull Channel channel) {
+               channels.remove(channel);
+       }
+
+       /**
+        * Writes an object to select channels based on a routing key, without flushing them.
+        *
+        * <p>Use {@code writeAndFlush} if the object should be flushed right away.
+        *
+        * @param routingKey the key to select the target channel(s).
+        * @param o the object to write.
+        * @return the future of the group write operation.
+        */
+       public ChannelGroupFuture write(@Nonnull K routingKey, @Nonnull Object o) {
+               return channels.write(o, channel -> isMatch(channel, routingKey));
+       }
+
+       /**
+        * Writes and flushes an object to select channels based on a routing key.
+        *
+        * @param routingKey the key to select the target channel(s).
+        * @param o the object to write and flush.
+        * @return the future of the group write-and-flush operation.
+        */
+       public ChannelGroupFuture writeAndFlush(@Nonnull K routingKey, @Nonnull Object o) {
+               return channels.writeAndFlush(o, channel -> isMatch(channel, routingKey));
+       }
+
+       /**
+        * Closes all channels associated with the given routing key.
+        *
+        * @param routingKey the key to select the channel(s) to close.
+        * @return the future of the group close operation.
+        */
+       public ChannelGroupFuture close(@Nonnull K routingKey) {
+               return channels.close(channel -> isMatch(channel, routingKey));
+       }
+
+       /** Returns whether the routing key stored on the channel equals the given key. */
+       private boolean isMatch(Channel channel, K routingKey) {
+               return Objects.equals(routingKey, channel.attr(attributeKey).get());
+       }
+}
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/websocket/WebSocket.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/websocket/WebSocket.java
new file mode 100644
index 00000000000..77d53ef4427
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/websocket/WebSocket.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.rest.websocket;
+
+import org.apache.flink.runtime.rest.messages.RequestBody;
+import org.apache.flink.runtime.rest.messages.ResponseBody;
+
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFuture;
+
+/**
+ * A WebSocket over which messages are sent and received.
+ *
+ * @param <I> type of the server-to-client messages.
+ * @param <O> type of the client-to-server messages.
+ */
+public interface WebSocket<I extends ResponseBody, O extends RequestBody> {
+
+       /**
+        * Registers a listener for the server-to-client messages of this WebSocket.
+        *
+        * @param listener the listener to add
+        */
+       void addListener(WebSocketListener<I> listener);
+
+       /**
+        * Sends a client-to-server message.
+        *
+        * @param message the message to send
+        * @return a future for the send operation
+        */
+       ChannelFuture send(O message);
+
+       /**
+        * Closes this WebSocket.
+        *
+        * @return a future for the close operation
+        */
+       ChannelFuture close();
+}
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/websocket/WebSocketListener.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/websocket/WebSocketListener.java
new file mode 100644
index 00000000000..33410fc4fcc
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/websocket/WebSocketListener.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.rest.websocket;
+
+import org.apache.flink.runtime.rest.messages.ResponseBody;
+import org.apache.flink.runtime.util.event.EventListener;
+
+/**
+ * A listener for WebSocket messages.
+ *
+ * @param <T> type of the server-to-client messages.
+ */
+public interface WebSocketListener<T extends ResponseBody> extends 
EventListener<T> { }
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/rest/RestEndpointITCase.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/rest/RestEndpointITCase.java
index be5985f3a2c..0e4218c31ff 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/rest/RestEndpointITCase.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/rest/RestEndpointITCase.java
@@ -23,6 +23,7 @@
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.rest.handler.AbstractRestHandler;
+import org.apache.flink.runtime.rest.handler.AbstractWebSocketHandler;
 import org.apache.flink.runtime.rest.handler.HandlerRequest;
 import org.apache.flink.runtime.rest.handler.RestHandlerException;
 import org.apache.flink.runtime.rest.handler.RestHandlerSpecification;
@@ -33,7 +34,10 @@
 import org.apache.flink.runtime.rest.messages.MessageQueryParameter;
 import org.apache.flink.runtime.rest.messages.RequestBody;
 import org.apache.flink.runtime.rest.messages.ResponseBody;
+import org.apache.flink.runtime.rest.messages.WebSocketSpecification;
 import org.apache.flink.runtime.rest.util.RestClientException;
+import org.apache.flink.runtime.rest.websocket.KeyedChannelRouter;
+import org.apache.flink.runtime.rest.websocket.WebSocket;
 import org.apache.flink.runtime.rpc.RpcUtils;
 import org.apache.flink.runtime.testingUtils.TestingUtils;
 import org.apache.flink.runtime.webmonitor.RestfulGateway;
@@ -42,8 +46,14 @@
 import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.TestLogger;
 
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFutureListener;
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext;
 import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandler;
+import org.apache.flink.shaded.netty4.io.netty.channel.group.ChannelGroup;
+import 
org.apache.flink.shaded.netty4.io.netty.channel.group.DefaultChannelGroup;
 import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponseStatus;
+import org.apache.flink.shaded.netty4.io.netty.util.AttributeKey;
+import 
org.apache.flink.shaded.netty4.io.netty.util.concurrent.GlobalEventExecutor;
 
 import com.fasterxml.jackson.annotation.JsonCreator;
 import com.fasterxml.jackson.annotation.JsonProperty;
@@ -55,11 +65,14 @@
 import javax.annotation.Nonnull;
 
 import java.net.InetSocketAddress;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Optional;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutionException;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
 
 import static org.mockito.Matchers.any;
 import static org.mockito.Mockito.mock;
@@ -78,6 +91,8 @@
        private RestServerEndpoint serverEndpoint;
        private RestClient clientEndpoint;
 
+       private TestEventProvider eventProvider;
+
        @Before
        public void setup() throws Exception {
                Configuration config = new Configuration();
@@ -91,12 +106,22 @@ public void setup() throws Exception {
                GatewayRetriever<RestfulGateway> mockGatewayRetriever = 
mock(GatewayRetriever.class);
                
when(mockGatewayRetriever.getNow()).thenReturn(Optional.of(mockRestfulGateway));
 
+               // a REST operation
                TestHandler testHandler = new TestHandler(
                        CompletableFuture.completedFuture(restAddress),
                        mockGatewayRetriever,
                        RpcUtils.INF_TIMEOUT);
 
-               serverEndpoint = new TestRestServerEndpoint(serverConfig, 
testHandler);
+               // a WebSocket operation
+               ChannelGroup eventChannelGroup = new 
DefaultChannelGroup(GlobalEventExecutor.INSTANCE);
+               eventProvider = new TestEventProvider(eventChannelGroup);
+               TestWebSocketOperation.WsRestHandler testWebSocketHandler = new 
TestWebSocketOperation.WsRestHandler(
+                       CompletableFuture.completedFuture(restAddress),
+                       mockGatewayRetriever,
+                       eventProvider,
+                       RpcUtils.INF_TIMEOUT);
+
+               serverEndpoint = new TestRestServerEndpoint(serverConfig, 
testHandler, testWebSocketHandler);
                clientEndpoint = new TestRestClient(clientConfig);
 
                serverEndpoint.start();
@@ -194,19 +219,74 @@ public void testBadHandlerRequest() throws Exception {
                }
        }
 
+       /**
+        * Tests that a web socket operation works end-to-end.
+        *
+        * <p>The server pushes a message with sequence number 42. The client sends it back incremented
+        * (43), and the server handler answers with a further increment (44). Every wait is bounded,
+        * so a broken connection fails the test instead of hanging the build.
+        */
+       @Test
+       public void testWebSocketEndToEnd() throws Exception {
+               final long timeoutSeconds = 10L;
+
+               TestWebSocketOperation.WsParameters parameters = new TestWebSocketOperation.WsParameters();
+               parameters.jobIDPathParameter.resolve(PATH_JOB_ID);
+
+               final LinkedBlockingQueue<TestMessage> messageQueue = new LinkedBlockingQueue<>();
+               final InetSocketAddress serverAddress = serverEndpoint.getServerAddress();
+
+               // open a websocket connection with a listener that simply enqueues incoming messages;
+               // the queue is unbounded, so add() never blocks and cannot be interrupted
+               CompletableFuture<WebSocket<TestMessage, TestMessage>> response = clientEndpoint.sendWebSocketRequest(
+                       serverAddress.getHostName(),
+                       serverAddress.getPort(),
+                       new TestWebSocketOperation.WsHeaders(),
+                       parameters,
+                       messageQueue::add);
+
+               // wait (bounded) for the connection to be established
+               WebSocket<TestMessage, TestMessage> webSocket = response.get(timeoutSeconds, TimeUnit.SECONDS);
+               try {
+                       // wait for the server to register the channel (happens asynchronously after handshake complete)
+                       Assert.assertTrue(
+                               "the server did not register the WebSocket channel in time",
+                               TestWebSocketOperation.WsRestHandler.LATCH.await(timeoutSeconds, TimeUnit.SECONDS));
+
+                       // play ping-pong (server goes first)
+                       eventProvider.writeAndFlush(PATH_JOB_ID, new TestMessage(42));
+                       TestMessage received = messageQueue.poll(timeoutSeconds, TimeUnit.SECONDS);
+                       Assert.assertNotNull("no message was pushed by the server", received);
+                       Assert.assertEquals(42, received.sequenceNumber);
+
+                       webSocket.send(received.incr());
+                       received = messageQueue.poll(timeoutSeconds, TimeUnit.SECONDS);
+                       Assert.assertNotNull("the server did not answer the client message", received);
+                       Assert.assertEquals(44, received.sequenceNumber);
+               }
+               finally {
+                       webSocket.close();
+               }
+       }
+
+       /**
+        * Test REST server endpoint that serves both the plain REST test handler and the
+        * WebSocket test handler.
+        */
        private static class TestRestServerEndpoint extends RestServerEndpoint {
 
                private final TestHandler testHandler;
+               private final TestWebSocketOperation.WsRestHandler 
testWebSocketHandler;
 
-               TestRestServerEndpoint(RestServerEndpointConfiguration 
configuration, TestHandler testHandler) {
+               TestRestServerEndpoint(RestServerEndpointConfiguration 
configuration, TestHandler testHandler, TestWebSocketOperation.WsRestHandler 
testWebSocketHandler) {
                        super(configuration);
 
                        this.testHandler = 
Preconditions.checkNotNull(testHandler);
+                       this.testWebSocketHandler = 
Preconditions.checkNotNull(testWebSocketHandler);
                }
 
                @Override
                protected Collection<Tuple2<RestHandlerSpecification, 
ChannelInboundHandler>> initializeHandlers(CompletableFuture<String> 
restAddressFuture) {
-                       return Collections.singleton(Tuple2.of(new 
TestHeaders(), testHandler));
+                       // register the REST operation and the WebSocket operation side by side
+                       return Arrays.asList(
+                               Tuple2.of(new TestHeaders(), testHandler),
+                               Tuple2.of(new 
TestWebSocketOperation.WsHeaders(), testWebSocketHandler));
+               }
+       }
+
+       /**
+        * Channel router keyed by {@link JobID}. The test uses it to push server-side events to
+        * registered WebSocket channels (see {@code register} / {@code writeAndFlush}).
+        * NOTE(review): the attribute key named {@code "jobID"} is presumably used to tag each
+        * registered channel with its job id; confirm against KeyedChannelRouter.
+        */
+       private static class TestEventProvider extends 
KeyedChannelRouter<JobID> {
+               public TestEventProvider(ChannelGroup channelGroup) {
+                       super(AttributeKey.valueOf("jobID"), channelGroup);
                }
        }
 
@@ -325,6 +405,88 @@ public TestParameters getUnresolvedMessageParameters() {
                }
        }
 
+       /**
+        * The WebSocket operation exercised by {@code testWebSocketEndToEnd}. It bundles the
+        * operation's message parameters, its specification and its server-side handler.
+        */
+       private static class TestWebSocketOperation {
+
+               /**
+                * Message parameters of the operation: one job id path parameter, no query parameters.
+                */
+               private static class WsParameters extends MessageParameters {
+                       private final JobIDPathParameter jobIDPathParameter = new JobIDPathParameter();
+
+                       @Override
+                       public Collection<MessageQueryParameter<?>> getQueryParameters() {
+                               return Collections.<MessageQueryParameter<?>>emptyList();
+                       }
+
+                       @Override
+                       public Collection<MessagePathParameter<?>> getPathParameters() {
+                               return Collections.<MessagePathParameter<?>>singleton(jobIDPathParameter);
+                       }
+               }
+
+               /**
+                * Specification of the operation: URL template, sub-protocol, and the message type
+                * used in both directions.
+                */
+               static class WsHeaders implements WebSocketSpecification<WsParameters, TestMessage, TestMessage> {
+
+                       @Override
+                       public WsParameters getUnresolvedMessageParameters() {
+                               return new WsParameters();
+                       }
+
+                       @Override
+                       public String getTargetRestEndpointURL() {
+                               return "/test/:jobid/subscribe";
+                       }
+
+                       @Override
+                       public String getSubprotocol() {
+                               return "test";
+                       }
+
+                       @Override
+                       public Class<TestMessage> getClientClass() {
+                               return TestMessage.class;
+                       }
+
+                       @Override
+                       public Class<TestMessage> getServerClass() {
+                               return TestMessage.class;
+                       }
+               }
+
+               /**
+                * Server-side handler. It:
+                * <ul>
+                *   <li>checks the requested job id while the handshake is in progress;</li>
+                *   <li>registers the channel with the {@link TestEventProvider} once the handshake completes;</li>
+                *   <li>answers every incoming message with a copy whose sequence number is incremented.</li>
+                * </ul>
+                */
+               static class WsRestHandler extends AbstractWebSocketHandler<RestfulGateway, WsParameters, TestMessage, TestMessage> {
+
+                       /** Counted down once a channel has been registered with the event provider. */
+                       public static final CountDownLatch LATCH = new CountDownLatch(1);
+
+                       private final TestEventProvider eventProvider;
+
+                       WsRestHandler(
+                               CompletableFuture<String> localAddressFuture,
+                               GatewayRetriever<RestfulGateway> leaderRetriever,
+                               TestEventProvider eventProvider,
+                               Time timeout) {
+                               super(localAddressFuture, leaderRetriever, timeout, new WsHeaders());
+                               this.eventProvider = eventProvider;
+                       }
+
+                       @Override
+                       protected CompletableFuture<Void> handshakeInitiated(ChannelHandlerContext ctx, WsParameters parameters) throws Exception {
+                               // validate the requested job before letting the handshake complete
+                               final JobID requestedJob = parameters.getPathParameter(JobIDPathParameter.class);
+                               Assert.assertEquals(PATH_JOB_ID, requestedJob);
+                               return CompletableFuture.completedFuture(null);
+                       }
+
+                       @Override
+                       protected void handshakeCompleted(ChannelHandlerContext ctx, WsParameters parameters) throws Exception {
+                               // subscribe this channel to server events for the requested job
+                               final JobID subscribedJob = parameters.getPathParameter(JobIDPathParameter.class);
+                               eventProvider.register(ctx.channel(), subscribedJob);
+                               LATCH.countDown();
+                       }
+
+                       @Override
+                       protected void messageReceived(ChannelHandlerContext ctx, WsParameters parameters, TestMessage msg) throws Exception {
+                               // echo with an incremented sequence number; close the channel if the write fails
+                               final TestMessage reply = msg.incr();
+                               ctx.channel()
+                                       .writeAndFlush(reply)
+                                       .addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
+                       }
+               }
+       }
+
        static class JobIDPathParameter extends MessagePathParameter<JobID> {
                JobIDPathParameter() {
                        super(JOB_ID_KEY);
@@ -373,4 +535,17 @@ public String convertStringToValue(JobID value) {
                        return value.toString();
                }
        }
+
+       /**
+        * JSON message exchanged in both directions by the test WebSocket operation.
+        */
+       static class TestMessage implements ResponseBody, RequestBody {
+
+               /** Sequence number carried by the message (JSON property {@code sequenceNumber}). */
+               public final int sequenceNumber;
+
+               @JsonCreator
+               public TestMessage(@JsonProperty("sequenceNumber") int sequenceNumber) {
+                       this.sequenceNumber = sequenceNumber;
+               }
+
+               /**
+                * Returns a new message whose sequence number is one greater than this one's.
+                */
+               public TestMessage incr() {
+                       final int next = sequenceNumber + 1;
+                       return new TestMessage(next);
+               }
+       }
 }
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/rest/handler/AbstractWebSocketHandlerTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/rest/handler/AbstractWebSocketHandlerTest.java
new file mode 100644
index 00000000000..951ecfa8899
--- /dev/null
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/rest/handler/AbstractWebSocketHandlerTest.java
@@ -0,0 +1,456 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.rest.handler;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.time.Time;
+import org.apache.flink.runtime.rest.messages.MessageParameters;
+import org.apache.flink.runtime.rest.messages.MessagePathParameter;
+import org.apache.flink.runtime.rest.messages.MessageQueryParameter;
+import org.apache.flink.runtime.rest.messages.RequestBody;
+import org.apache.flink.runtime.rest.messages.ResponseBody;
+import org.apache.flink.runtime.rest.messages.WebSocketSpecification;
+import org.apache.flink.runtime.rpc.RpcUtils;
+import org.apache.flink.runtime.webmonitor.RestfulGateway;
+import org.apache.flink.runtime.webmonitor.retriever.GatewayRetriever;
+
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+import org.apache.flink.shaded.netty4.io.netty.channel.Channel;
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext;
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandler;
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInitializer;
+import 
org.apache.flink.shaded.netty4.io.netty.channel.embedded.EmbeddedChannel;
+import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.FullHttpRequest;
+import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpHeaders;
+import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpRequestDecoder;
+import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponse;
+import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponseEncoder;
+import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponseStatus;
+import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.router.Handler;
+import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.router.Router;
+import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.websocketx.WebSocketClientHandshaker13;
+import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
+import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.websocketx.WebSocketVersion;
+import org.apache.flink.shaded.netty4.io.netty.util.AbstractReferenceCounted;
+import org.apache.flink.shaded.netty4.io.netty.util.ReferenceCountUtil;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.net.URI;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.LinkedBlockingQueue;
+
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Tests the {@link AbstractWebSocketHandler}.
+ */
+public class AbstractWebSocketHandlerTest {
+
+       private static final String TEST_SUBPROTOCOL = "test";
+       private static final String REST_ADDRESS = "http://localhost:1234";
+       private static final JobID TEST_JOB_ID = new JobID();
+
+       private GatewayRetriever<RestfulGateway> gatewayRetriever;
+
+       /**
+        * Sets up a mock {@link RestfulGateway} that reports {@link #REST_ADDRESS}, plus a mock
+        * {@link GatewayRetriever} that always returns it. Together they satisfy the RedirectHandler.
+        */
+       @Before
+       @SuppressWarnings("unchecked")
+       public void setup() throws Exception {
+               RestfulGateway mockRestfulGateway = mock(RestfulGateway.class);
+               when(mockRestfulGateway.requestRestAddress(any(Time.class)))
+                       .thenReturn(CompletableFuture.completedFuture(REST_ADDRESS));
+
+               // mocking a generic interface yields a raw type, hence the unchecked suppression
+               gatewayRetriever = mock(GatewayRetriever.class);
+               when(gatewayRetriever.getNow()).thenReturn(Optional.of(mockRestfulGateway));
+       }
+
+       /**
+        * Tests that a handshake request whose mandatory query parameter has an unparseable value
+        * is rejected with {@code 400 Bad Request}, and that the channel stays open.
+        */
+       @Test
+       public void testInvalidMessageParameter() {
+
+               EmbeddedChannel channel = channel(TestSpecification.INSTANCE, new TestHandler());
+
+               // Write a handshake request whose mandatory query parameter is not a valid integer.
+               // Only the "q=42" fragment is rewritten: the random job id in the path may itself
+               // contain "42", and corrupting it would make the request fail for the wrong reason.
+               TestParameters params = TestParameters.from(TEST_JOB_ID, 42);
+               URI requestUri = URI.create(requestUri(TestSpecification.INSTANCE, params).toString()
+                       .replace("q=42", "q=NA"));
+               Assert.assertEquals("q=NA", requestUri.getQuery());
+               FullHttpRequest request = handshakeRequest(requestUri, TEST_SUBPROTOCOL);
+               channel.writeInbound(request);
+
+               // verify that the request was released and rejected
+               HttpResponse response = (HttpResponse) channel.readOutbound();
+               Assert.assertEquals(0, request.refCnt());
+               Assert.assertNotNull(response);
+               Assert.assertEquals(HttpResponseStatus.BAD_REQUEST, response.getStatus());
+               Assert.assertTrue(channel.isOpen());
+       }
+
+       /**
+        * Tests that a handshake request which omits the mandatory query parameter is rejected
+        * with {@code 400 Bad Request}, and that the channel stays open.
+        */
+       @Test
+       public void testMissingMessageParameter() {
+
+               EmbeddedChannel channel = channel(TestSpecification.INSTANCE, new TestHandler());
+
+               // Turn every 'q' into '_'. Only the query key contains one (the host, the path template
+               // and the hex job id do not), so the mandatory parameter "q" disappears from the request.
+               String validUri = requestUri(TestSpecification.INSTANCE, TestParameters.from(TEST_JOB_ID, 42)).toString();
+               FullHttpRequest request = handshakeRequest(URI.create(validUri.replace("q", "_")), TEST_SUBPROTOCOL);
+               channel.writeInbound(request);
+
+               // the request must have been released and answered with 400
+               HttpResponse response = (HttpResponse) channel.readOutbound();
+               Assert.assertEquals(0, request.refCnt());
+               Assert.assertEquals(HttpResponseStatus.BAD_REQUEST, response.getStatus());
+               Assert.assertTrue(channel.isOpen());
+       }
+
+       /**
+        * Tests that a well-formed handshake request is accepted and that the resolved message
+        * parameters are passed to {@code handshakeInitiated(...)}.
+        */
+       @Test
+       public void testValidMessageParameters() {
+
+               TestHandler handler = new TestHandler();
+               EmbeddedChannel channel = channel(TestSpecification.INSTANCE, handler);
+
+               // write a handshake request carrying the mandatory parameters
+               URI uri = requestUri(TestSpecification.INSTANCE, TestParameters.from(TEST_JOB_ID, 42));
+               FullHttpRequest request = handshakeRequest(uri, TEST_SUBPROTOCOL);
+               channel.writeInbound(request);
+
+               // the upgrade must succeed and the parsed query value must have reached handshakeInitiated
+               HttpResponse response = (HttpResponse) channel.readOutbound();
+               Assert.assertEquals(0, request.refCnt());
+               Assert.assertEquals(HttpResponseStatus.SWITCHING_PROTOCOLS, response.getStatus());
+               TestParameters seen = handler.handshakeInitiated;
+               Assert.assertNotNull(seen);
+               Assert.assertEquals(Collections.singletonList(42), seen.queryParameter.getValue());
+       }
+
+       /**
+        * Tests that a valid handshake request installs the WebSocket handlers in the pipeline
+        * and is answered with {@code 101 Switching Protocols}.
+        */
+       @Test
+       public void testWebSocketUpgrade() {
+
+               EmbeddedChannel channel = channel(TestSpecification.INSTANCE, new TestHandler());
+
+               URI uri = requestUri(TestSpecification.INSTANCE, TestParameters.from(TEST_JOB_ID, 42));
+               FullHttpRequest request = handshakeRequest(uri, TEST_SUBPROTOCOL);
+               channel.writeInbound(request);
+
+               // the protocol handler and the JSON message codec must now be part of the pipeline
+               Assert.assertNotNull(channel.pipeline().get(WebSocketServerProtocolHandler.class));
+               Assert.assertNotNull(channel.pipeline().get(JsonWebSocketMessageCodec.class));
+
+               // and the handshake response must have been written back
+               HttpResponse response = (HttpResponse) channel.readOutbound();
+               Assert.assertEquals(0, request.refCnt());
+               Assert.assertNotNull(response);
+               Assert.assertEquals(HttpResponseStatus.SWITCHING_PROTOCOLS, response.getStatus());
+       }
+
+       /**
+        * Tests that a handshake request without the {@code Sec-WebSocket-Key} header is answered
+        * with {@code 400 Bad Request}, and that the channel is closed.
+        */
+       @Test
+       public void testWebSocketUpgradeFailure() {
+
+               EmbeddedChannel channel = channel(TestSpecification.INSTANCE, new TestHandler());
+
+               // build a valid handshake request, then break it by dropping the key header
+               URI uri = requestUri(TestSpecification.INSTANCE, TestParameters.from(TEST_JOB_ID, 42));
+               FullHttpRequest request = handshakeRequest(uri, TEST_SUBPROTOCOL);
+               request.headers().remove("Sec-WebSocket-Key");
+               channel.writeInbound(request);
+
+               HttpResponse response = (HttpResponse) channel.readOutbound();
+               Assert.assertEquals(0, request.refCnt());
+               Assert.assertNotNull(response);
+               Assert.assertEquals(HttpResponseStatus.BAD_REQUEST, response.getStatus());
+               Assert.assertFalse(channel.isOpen());
+       }
+
+       /**
+        * Tests that, after a successful upgrade, a client message reaches the handler and the
+        * handler's reply is written out as a WebSocket text frame.
+        */
+       @Test
+       public void testMessageReadWrite() throws Exception {
+               TestHandler handler = new TestHandler();
+               EmbeddedChannel channel = channel(TestSpecification.INSTANCE, handler);
+
+               TestParameters params = TestParameters.from(TEST_JOB_ID, 42);
+               FullHttpRequest request = handshakeRequest(requestUri(TestSpecification.INSTANCE, params), TEST_SUBPROTOCOL);
+               channel.writeInbound(request);
+               HttpResponse response = (HttpResponse) channel.readOutbound();
+               Assert.assertNotNull(response);
+               Assert.assertEquals(HttpResponseStatus.SWITCHING_PROTOCOLS, response.getStatus());
+
+               ClientMessage expected = new ClientMessage(42);
+               channel.writeInbound(expected);
+
+               // EmbeddedChannel runs the pipeline on the calling thread, so the handler has already
+               // queued the message; poll() fails fast, whereas take() would block forever
+               ClientMessage actual = handler.messages.poll();
+               Assert.assertNotNull("the handler did not receive the client message", actual);
+               Assert.assertEquals(expected.sequenceNumber, actual.sequenceNumber);
+               Assert.assertEquals(1, expected.refCnt());
+               Assert.assertTrue(channel.isOpen());
+
+               // The reply is an on-the-wire text frame, presumably a 2-byte frame header plus the
+               // 21-byte JSON payload {"sequenceNumber":42}. Release it even if the assertion fails.
+               ByteBuf msg = (ByteBuf) channel.readOutbound();
+               Assert.assertNotNull("no reply was written", msg);
+               try {
+                       Assert.assertEquals(23, msg.readableBytes());
+               } finally {
+                       ReferenceCountUtil.release(msg);
+               }
+       }
+
+       /**
+        * Path parameter {@code jobid} holding a {@link JobID}, parsed with {@link JobID#fromHexString}.
+        */
+       static class TestPathParameter extends MessagePathParameter<JobID> {
+               TestPathParameter() {
+                       super("jobid");
+               }
+
+               @Override
+               protected String convertToString(JobID jobId) {
+                       return jobId.toString();
+               }
+
+               @Override
+               public JobID convertFromString(String hexString) {
+                       return JobID.fromHexString(hexString);
+               }
+       }
+
+       /**
+        * Mandatory integer query parameter named {@code q}.
+        */
+       private static class TestQueryParameter extends MessageQueryParameter<Integer> {
+
+               TestQueryParameter() {
+                       super("q", MessageParameterRequisiteness.MANDATORY);
+               }
+
+               @Override
+               public String convertStringToValue(Integer number) {
+                       return Integer.toString(number);
+               }
+
+               @Override
+               public Integer convertValueFromString(String text) {
+                       return Integer.valueOf(text);
+               }
+       }
+
+       /**
+        * Message parameters of the test WebSocket resource: the {@code jobid} path parameter and
+        * the {@code q} query parameter.
+        */
+       private static class TestParameters extends MessageParameters {
+               private final TestPathParameter pathParameter = new TestPathParameter();
+               private final TestQueryParameter queryParameter = new TestQueryParameter();
+
+               /**
+                * Creates parameters resolved to the given job id and a single query value.
+                */
+               public static TestParameters from(JobID jobID, int sequenceNumber) {
+                       TestParameters resolved = new TestParameters();
+                       resolved.pathParameter.resolve(jobID);
+                       resolved.queryParameter.resolve(Collections.singletonList(sequenceNumber));
+                       return resolved;
+               }
+
+               @Override
+               public Collection<MessagePathParameter<?>> getPathParameters() {
+                       return Collections.singleton(pathParameter);
+               }
+
+               @Override
+               public Collection<MessageQueryParameter<?>> getQueryParameters() {
+                       return Collections.singleton(queryParameter);
+               }
+       }
+
+       /**
+        * Specification of the test WebSocket resource served at {@code /jobs/:jobid/subscribe}.
+        */
+       static class TestSpecification implements WebSocketSpecification<TestParameters, ClientMessage, ServerMessage> {
+
+               static final TestSpecification INSTANCE = new TestSpecification();
+
+               @Override
+               public TestParameters getUnresolvedMessageParameters() {
+                       return new TestParameters();
+               }
+
+               @Override
+               public String getTargetRestEndpointURL() {
+                       return "/jobs/:jobid/subscribe";
+               }
+
+               @Override
+               public String getSubprotocol() {
+                       return "test";
+               }
+
+               @Override
+               public Class<ClientMessage> getClientClass() {
+                       return ClientMessage.class;
+               }
+
+               @Override
+               public Class<ServerMessage> getServerClass() {
+                       return ServerMessage.class;
+               }
+       }
+
+       /**
+        * Handler for the test WebSocket resource. It records the parameters seen by the handshake
+        * callbacks and queues every received message. Each message is answered with a
+        * {@link ServerMessage} carrying the same sequence number.
+        */
+       private class TestHandler extends AbstractWebSocketHandler<RestfulGateway, TestParameters, ClientMessage, ServerMessage> {
+
+               /** Parameters passed to {@code handshakeInitiated(...)}, or {@code null} before it is called. */
+               TestParameters handshakeInitiated = null;
+
+               /** Parameters passed to {@code handshakeCompleted(...)}, or {@code null} before it is called. */
+               TestParameters handshakeCompleted = null;
+
+               /** Messages received from the client; each one is retained once before being queued. */
+               final LinkedBlockingQueue<ClientMessage> messages = new LinkedBlockingQueue<>();
+
+               public TestHandler() {
+                       super(CompletableFuture.completedFuture(REST_ADDRESS), gatewayRetriever, RpcUtils.INF_TIMEOUT, TestSpecification.INSTANCE);
+               }
+
+               @Override
+               protected CompletableFuture<Void> handshakeInitiated(ChannelHandlerContext ctx, TestParameters parameters) throws Exception {
+                       this.handshakeInitiated = parameters;
+                       return CompletableFuture.completedFuture(null);
+               }
+
+               @Override
+               protected void handshakeCompleted(ChannelHandlerContext ctx, TestParameters parameters) throws Exception {
+                       this.handshakeCompleted = parameters;
+               }
+
+               @Override
+               protected void messageReceived(ChannelHandlerContext ctx, TestParameters parameters, ClientMessage msg) throws Exception {
+                       // retain so the message survives a release performed after this callback
+                       // (presumably by the base handler); the tests inspect it afterwards
+                       ReferenceCountUtil.retain(msg);
+                       messages.put(msg);
+
+                       final ServerMessage reply = new ServerMessage(msg.sequenceNumber);
+                       ctx.channel().writeAndFlush(reply);
+               }
+       }
+
+       /**
+        * Reference-counted message sent from the client to the server, so tests can check
+        * retain/release behaviour.
+        */
+       private static class ClientMessage extends AbstractReferenceCounted implements RequestBody {
+               public final int sequenceNumber;
+
+               public ClientMessage(@JsonProperty("sequenceNumber") int sequenceNumber) {
+                       this.sequenceNumber = sequenceNumber;
+               }
+
+               @Override
+               protected void deallocate() {
+                       // intentionally empty: the message owns no resources that need freeing
+               }
+       }
+
+       /**
+        * Reference-counted message sent from the server to the client.
+        */
+       private static class ServerMessage extends AbstractReferenceCounted implements ResponseBody {
+               public final int sequenceNumber;
+
+               public ServerMessage(@JsonProperty("sequenceNumber") int sequenceNumber) {
+                       this.sequenceNumber = sequenceNumber;
+               }
+
+               @Override
+               protected void deallocate() {
+                       // intentionally empty: the message owns no resources that need freeing
+               }
+       }
+
+       // ----- Utility -----
+
+       /**
+        * Creates an embedded channel that decodes HTTP requests and routes
+        * {@code GET spec.getTargetRestEndpointURL()} to {@code handler}.
+        *
+        * <p>The WebSocket handshaker looks for an HTTP decoder/encoder to know where to inject
+        * the WebSocket handlers. A pass-through response encoder is installed so that tests can
+        * read and assert on the response objects rather than on encoded bytes.
+        */
+       private EmbeddedChannel channel(RestHandlerSpecification spec, ChannelInboundHandler handler) {
+               final ChannelInitializer<Channel> initializer = new ChannelInitializer<Channel>() {
+                       @Override
+                       protected void initChannel(Channel channel) throws Exception {
+                               final Router router = new Router();
+                               router.GET(spec.getTargetRestEndpointURL(), handler);
+                               final Handler routerHandler = new Handler(router);
+
+                               channel.pipeline().addLast(new HttpRequestDecoder());
+                               channel.pipeline().addLast(new MockHttpResponseEncoder());
+                               channel.pipeline().addLast(routerHandler.name(), routerHandler);
+                       }
+               };
+               return new EmbeddedChannel(initializer);
+       }
+
+       /**
+        * Returns the absolute request URI for the given spec, with its URL template resolved
+        * against {@code parameters} and rooted at {@link #REST_ADDRESS}.
+        */
+       private static URI requestUri(RestHandlerSpecification spec, MessageParameters parameters) {
+               final String relativeUrl = MessageParameters.resolveUrl(spec.getTargetRestEndpointURL(), parameters);
+               final URI relative = URI.create(relativeUrl);
+               return URI.create(REST_ADDRESS).resolve(relative);
+       }
+
+       /**
+        * Builds a WebSocket (version 13) handshake request for {@code requestUri} that offers
+        * {@code subprotocol}. The {@code http:}/{@code https:} scheme is rewritten to
+        * {@code ws:}/{@code wss:}.
+        */
+       private static FullHttpRequest handshakeRequest(URI requestUri, String subprotocol) {
+               final String webSocketUrl = requestUri.toString()
+                       .replace("http:", "ws:")
+                       .replace("https:", "wss:");
+               return new MockHandshaker(URI.create(webSocketUrl), subprotocol, HttpHeaders.EMPTY_HEADERS)
+                       .newHandshakeRequest();
+       }
+
+       /**
+        * Exposes Netty's version-13 client handshaker request factory, so that tests can build
+        * valid handshake requests without a real client connection.
+        */
+       private static class MockHandshaker extends WebSocketClientHandshaker13 {
+               public MockHandshaker(URI webSocketURL, String subprotocol, HttpHeaders customHeaders) {
+                       // extensions disallowed; maximum frame payload length of Integer.MAX_VALUE
+                       super(webSocketURL, WebSocketVersion.V13, subprotocol, false, customHeaders, Integer.MAX_VALUE);
+               }
+
+               /**
+                * Re-exposes the inherited handshake request factory as a public method.
+                */
+               @Override
+               public FullHttpRequest newHandshakeRequest() {
+                       return super.newHandshakeRequest();
+               }
+       }
+
+       /**
+        * Pass-through stand-in for {@link HttpResponseEncoder}. The WebSocket handshaker requires
+        * an HTTP encoder in the pipeline, but the tests want to read response objects unencoded.
+        */
+       private static class MockHttpResponseEncoder extends HttpResponseEncoder {
+               @Override
+               protected void encode(ChannelHandlerContext ctx, Object msg, List<Object> out) throws Exception {
+                       if (msg == null) {
+                               return;
+                       }
+                       // retain before forwarding: the encoder base class releases the original
+                       // message once encode() returns
+                       out.add(ReferenceCountUtil.retain(msg));
+               }
+       }
+}
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/rest/handler/JsonWebSocketMessageCodecTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/rest/handler/JsonWebSocketMessageCodecTest.java
new file mode 100644
index 00000000000..48283f34bb5
--- /dev/null
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/rest/handler/JsonWebSocketMessageCodecTest.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.rest.handler;
+
+import 
org.apache.flink.shaded.netty4.io.netty.channel.embedded.EmbeddedChannel;
+import org.apache.flink.shaded.netty4.io.netty.handler.codec.DecoderException;
+import org.apache.flink.shaded.netty4.io.netty.handler.codec.EncoderException;
+import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Tests for {@link JsonWebSocketMessageCodec}. They cover decoding JSON text frames into
+ * objects on the inbound path and encoding objects into JSON text frames on the outbound path.
+ * Failures are expected to surface as Netty {@link DecoderException} / {@link EncoderException}.
+ */
+public class JsonWebSocketMessageCodecTest {
+
+       @Test
+       public void testDecodeSuccess() {
+               EmbeddedChannel channel = newChannel();
+               channel.writeInbound(new TextWebSocketFrame("{\"sequenceNumber\":42}"));
+
+               TestMessage decoded = (TestMessage) channel.readInbound();
+               Assert.assertNotNull(decoded);
+               Assert.assertEquals(42, decoded.sequenceNumber);
+       }
+
+       @Test(expected = DecoderException.class)
+       public void testDecodeFailure() {
+               // An empty frame is not valid JSON, so decoding must fail.
+               newChannel().writeInbound(new TextWebSocketFrame(""));
+       }
+
+       @Test
+       public void testEncodeSuccess() {
+               EmbeddedChannel channel = newChannel();
+               channel.writeOutbound(new TestMessage(42));
+
+               TextWebSocketFrame encoded = (TextWebSocketFrame) channel.readOutbound();
+               Assert.assertNotNull(encoded);
+               Assert.assertEquals("{\"sequenceNumber\":42}", encoded.text());
+       }
+
+       @Test(expected = EncoderException.class)
+       public void testEncodeFailure() {
+               // TestMessage's getter throws for negative values, so this message cannot be serialized.
+               newChannel().writeOutbound(new TestMessage(-1));
+       }
+
+       /**
+        * Creates an embedded channel whose only handler is a codec for {@link TestMessage}
+        * in both the inbound and outbound direction.
+        */
+       private static EmbeddedChannel newChannel() {
+               return new EmbeddedChannel(new JsonWebSocketMessageCodec<>(TestMessage.class, TestMessage.class));
+       }
+
+       /**
+        * Minimal JSON-mappable message carrying a single sequence number. Its getter throws
+        * {@link IllegalStateException} for negative values.
+        *
+        * <p>NOTE(review): the Jackson annotations are imported from unshaded
+        * {@code com.fasterxml.jackson.annotation}. If the codec uses a shaded Jackson, these
+        * annotations would presumably be ignored; confirm which Jackson the codec relies on.
+        */
+       private static class TestMessage {
+
+               private final int sequenceNumber;
+
+               @JsonCreator
+               public TestMessage(@JsonProperty("sequenceNumber") int sequenceNumber) {
+                       this.sequenceNumber = sequenceNumber;
+               }
+
+               public int getSequenceNumber() {
+                       if (sequenceNumber >= 0) {
+                               return sequenceNumber;
+                       }
+                       throw new IllegalStateException();
+               }
+       }
+}


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


> Create WebSocket handler (server)
> ---------------------------------
>
>                 Key: FLINK-7738
>                 URL: https://issues.apache.org/jira/browse/FLINK-7738
>             Project: Flink
>          Issue Type: Sub-task
>          Components: Cluster Management, Mesos
>            Reporter: Eron Wright 
>            Assignee: Eron Wright 
>            Priority: Major
>              Labels: pull-request-available
>
> An abstract handler is needed to support websocket communication.



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)

Reply via email to