GitHub user EronWright commented on a diff in the pull request:

    https://github.com/apache/flink/pull/4767#discussion_r144381682
  
    --- Diff: 
flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/AbstractWebSocketHandler.java
 ---
    @@ -0,0 +1,301 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *     http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.flink.runtime.rest.handler;
    +
    +import org.apache.flink.api.common.time.Time;
    +import org.apache.flink.runtime.concurrent.FutureUtils;
    +import org.apache.flink.runtime.rest.handler.util.HandlerUtils;
    +import org.apache.flink.runtime.rest.messages.ErrorResponseBody;
    +import org.apache.flink.runtime.rest.messages.MessageParameters;
    +import org.apache.flink.runtime.rest.messages.RequestBody;
    +import org.apache.flink.runtime.rest.messages.ResponseBody;
    +import org.apache.flink.runtime.rest.messages.WebSocketSpecification;
    +import org.apache.flink.runtime.webmonitor.RestfulGateway;
    +import org.apache.flink.runtime.webmonitor.retriever.GatewayRetriever;
    +import org.apache.flink.util.ExceptionUtils;
    +import org.apache.flink.util.Preconditions;
    +
    +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandler;
    +import 
org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext;
    +import 
org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandlerAdapter;
    +import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpRequest;
    +import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponseStatus;
    +import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.router.Routed;
    +import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
    +import org.apache.flink.shaded.netty4.io.netty.util.AttributeKey;
    +import org.apache.flink.shaded.netty4.io.netty.util.ReferenceCountUtil;
    +
    +import org.slf4j.Logger;
    +import org.slf4j.LoggerFactory;
    +
    +import javax.annotation.Nonnull;
    +
    +import java.util.concurrent.CompletableFuture;
    +
    +/**
    + * A channel handler for WebSocket resources.
    + *
    + * <p>This handler performs the handshake and the subsequent messaging with a WebSocket client,
    + * based on a {@link WebSocketSpecification} that describes the REST resource location,
    + * parameter type, and inbound/outbound message types.  Messages are automatically
    + * decoded from, and encoded to, JSON text.
    + *
    + * <p>Subclasses should override the following methods to extend the respective phases.
    + * <ol>
    + *     <li>{@code handshakeInitiated} - occurs upon receipt of a handshake request from an HTTP client.  Useful for parameter validation.</li>
    + *     <li>{@code handshakeCompleted} - occurs upon successful completion of the handshake; the WebSocket is ready for I/O.</li>
    + *     <li>{@code messageReceived} - occurs when a WebSocket message is received on the channel.</li>
    + * </ol>
    + *
    + * <p>The handler supports gateway availability announcements.
    + *
    + * @param <T> The gateway type.
    + * @param <M> The REST parameter type.
    + * @param <O> The outbound message type.
    + * @param <I> The inbound message type.
    + */
    +public abstract class AbstractWebSocketHandler<T extends RestfulGateway, M 
extends MessageParameters, O extends RequestBody, I extends ResponseBody> 
extends ChannelInboundHandlerAdapter {
    +
    +   /** Logger named after the concrete (sub)class of this handler. */
    +   protected final Logger log = LoggerFactory.getLogger(getClass());
    +
    +   /** Describes the resource: its REST parameters and the message classes exchanged over the socket. */
    +   private final WebSocketSpecification<M, O, I> specification;
    +
    +   /** Installed in front of this handler by {@code handlerAdded}; announces the gateway via a user event. */
    +   private final RedirectHandler redirectHandler;
    +
    +   /** JSON codec installed into the pipeline when the channel is upgraded to a WebSocket. */
    +   private final ChannelHandler messageCodec;
    +
    +   /** Attribute holding the gateway received from the redirect handler. */
    +   private final AttributeKey<T> gatewayAttr;
    +
    +   /** Attribute holding the REST parameters resolved from the handshake request. */
    +   private final AttributeKey<M> parametersAttr;
    +
    +   /**
    +    * Creates a new handler.
    +    */
    +   public AbstractWebSocketHandler(
    +           @Nonnull CompletableFuture<String> localAddressFuture,
    +           @Nonnull GatewayRetriever<? extends T> leaderRetriever,
    +           @Nonnull Time timeout,
    +           @Nonnull WebSocketSpecification<M, O, I> specification) {
    +           this.redirectHandler = new 
RedirectHandler<>(localAddressFuture, leaderRetriever, timeout);
    +           this.gatewayAttr = AttributeKey.valueOf("gateway");
    +           this.specification = specification;
    +           this.messageCodec = new 
JsonWebSocketMessageCodec<>(specification.getInboundClass(), 
specification.getOutboundClass());
    +           this.parametersAttr = AttributeKey.valueOf("parameters");
    +   }
    +
    +   /**
    +    * Stores {@code gateway} in the attribute map reached through {@code ctx}.
    +    */
    +   private void setGateway(ChannelHandlerContext ctx, T gateway) {
    +           ctx.attr(this.gatewayAttr).set(gateway);
    +   }
    +
    +   /**
    +    * Looks up the gateway previously stored through {@code ctx}.
    +    *
    +    * @param ctx the channel handler context
    +    * @return the stored gateway, never {@code null}
    +    * @throws IllegalStateException if no gateway has been stored yet
    +    */
    +   public T getGateway(ChannelHandlerContext ctx) {
    +           final T gateway = ctx.attr(gatewayAttr).get();
    +           Preconditions.checkState(gateway != null, "Gateway is not available.");
    +           return gateway;
    +   }
    +
    +   /**
    +    * Stores the resource parameters in the attribute map reached through {@code ctx}.
    +    *
    +    * <p>These are the parameters resolved from the WebSocket handshake request.
    +    */
    +   private void setMessageParameters(ChannelHandlerContext ctx, M parameters) {
    +           ctx.attr(this.parametersAttr).set(parameters);
    +   }
    +
    +   /**
    +    * Looks up the resource parameters previously stored through {@code ctx}.
    +    *
    +    * @param ctx the channel handler context
    +    * @return the stored parameters, never {@code null}
    +    * @throws IllegalStateException if no parameters have been stored yet
    +    */
    +   public M getMessageParameters(ChannelHandlerContext ctx) {
    +           final M parameters = ctx.attr(parametersAttr).get();
    +           Preconditions.checkState(parameters != null, "Message parameters are not available.");
    +           return parameters;
    +   }
    +
    +   /**
    +    * Installs the redirect handler directly in front of this handler, unless the pipeline already
    +    * contains a handler registered under the redirect handler's class name.
    +    */
    +   @Override
    +   public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
    +           final String redirectName = RedirectHandler.class.getName();
    +           if (ctx.pipeline().get(redirectName) != null) {
    +                   return;
    +           }
    +           ctx.pipeline().addBefore(ctx.name(), redirectName, redirectHandler);
    +   }
    +
    +   @Override
    +   @SuppressWarnings("unchecked")
    +   public void userEventTriggered(ChannelHandlerContext ctx, Object evt) 
throws Exception {
    +           if (evt instanceof RedirectHandler.GatewayRetrieved) {
    +                   T gateway = ((RedirectHandler.GatewayRetrieved<T>) 
evt).getGateway();
    +                   setGateway(ctx, gateway);
    +                   log.debug("Gateway retrieved: {}", gateway);
    +           }
    +           else if (evt instanceof 
WebSocketServerProtocolHandler.ServerHandshakeStateEvent) {
    +                   
WebSocketServerProtocolHandler.ServerHandshakeStateEvent handshakeEvent = 
(WebSocketServerProtocolHandler.ServerHandshakeStateEvent) evt;
    +                   if (handshakeEvent == 
WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE) {
    +                           log.debug("Handshake completed with client IP: 
{}", ctx.channel().remoteAddress());
    +                           M parameters = getMessageParameters(ctx);
    +                           handshakeCompleted(ctx, parameters);
    +                   }
    +           }
    +
    +           super.userEventTriggered(ctx, evt);
    +   }
    +
    +   /**
    +    * Dispatches a message read from the channel.
    +    *
    +    * <p>Messages of the specification's inbound class are passed to {@code messageReceived}. A {@link Routed}
    +    * HTTP request is treated as a handshake request: its parameters are resolved, the subclass is consulted
    +    * via {@code handshakeInitiated}, and on success the channel is upgraded. Anything else is forwarded.
    +    */
    +   @Override
    +   @SuppressWarnings("unchecked")
    +   public void channelRead(ChannelHandlerContext ctx, Object o) throws Exception {
    +           if (specification.getInboundClass().isAssignableFrom(o.getClass())) {
    +                   // process an inbound WebSocket message
    +                   // NOTE(review): the message is cast to O, which the class Javadoc calls the outbound type;
    +                   // confirm which type parameter WebSocketSpecification#getInboundClass refers to
    +                   M parameters = getMessageParameters(ctx);
    +                   messageReceived(ctx, parameters, (O) o);
    +                   return;
    +           }
    +
    +           if (!(o instanceof Routed)) {
    +                   // a foreign message; pass it on untouched
    +                   ctx.fireChannelRead(o);
    +                   return;
    +           }
    +
    +           // process an inbound HTTP (handshake) request
    +           Routed request = (Routed) o;
    +
    +           // parse the REST request parameters
    +           M messageParameters = specification.getUnresolvedMessageParameters();
    +           try {
    +                   messageParameters.resolveParameters(request.pathParams(), request.queryParams());
    +                   if (!messageParameters.isResolved()) {
    +                           throw new IllegalArgumentException("One or more mandatory parameters is missing.");
    +                   }
    +           }
    +           catch (IllegalArgumentException e) {
    +                   try {
    +                           HandlerUtils.sendErrorResponse(
    +                                   ctx,
    +                                   request.request(),
    +                                   new ErrorResponseBody(String.format("Bad request, could not parse parameters: %s", e.getMessage())),
    +                                   HttpResponseStatus.BAD_REQUEST);
    +                   } finally {
    +                           // release the request even if sending the error response fails
    +                           ReferenceCountUtil.release(request);
    +                   }
    +                   return;
    +           }
    +           setMessageParameters(ctx, messageParameters);
    +
    +           // validate the inbound handshake request with the subclass
    +           CompletableFuture<Void> handshakeReady;
    +           try {
    +                   handshakeReady = handshakeInitiated(ctx, messageParameters);
    +                   if (handshakeReady == null) {
    +                           // would otherwise cause an NPE below, before the request has been released
    +                           throw new IllegalStateException("handshakeInitiated returned null.");
    +                   }
    +           } catch (Exception e) {
    +                   handshakeReady = FutureUtils.completedExceptionally(e);
    +           }
    +           handshakeReady.whenCompleteAsync((Void v, Throwable throwable) -> {
    +                   try {
    +                           if (throwable != null) {
    +                                   Throwable error = ExceptionUtils.stripCompletionException(throwable);
    +                                   if (error instanceof RestHandlerException) {
    +                                           final RestHandlerException rhe = (RestHandlerException) error;
    +                                           log.error("Exception occurred in REST handler.", error);
    +                                           HandlerUtils.sendErrorResponse(ctx, request.request(), new ErrorResponseBody(rhe.getMessage()), rhe.getHttpResponseStatus());
    +                                   } else {
    +                                           log.error("Implementation error: Unhandled exception.", error);
    +                                           HandlerUtils.sendErrorResponse(ctx, request.request(), new ErrorResponseBody("Internal server error."), HttpResponseStatus.INTERNAL_SERVER_ERROR);
    +                                   }
    +                           } else {
    +                                   upgradeToWebSocket(ctx, request);
    +                           }
    +                   } catch (RuntimeException e) {
    +                           // the future returned by whenCompleteAsync is discarded, so an exception escaping this
    +                           // callback would be lost silently and the client would never receive a response
    +                           log.error("Failed to process the WebSocket handshake request; closing the channel.", e);
    +                           ctx.close();
    +                   }
    +                   finally {
    +                           ReferenceCountUtil.release(request);
    +                   }
    +           }, ctx.executor());
    +   }
    +
    +   private void upgradeToWebSocket(final ChannelHandlerContext ctx, Routed 
msg) {
    +
    +           // store the context of the handler that precedes the current 
handler,
    +           // to use that context later to forward the HTTP request to the 
WebSocket protocol handler
    +           String before = 
ctx.pipeline().names().get(ctx.pipeline().names().indexOf(ctx.name()) - 1);
    +           ChannelHandlerContext beforeCtx = 
ctx.pipeline().context(before);
    +
    +           // inject the websocket protocol handler into this channel, to 
be active
    +           // until the channel is closed.  note that the handshake may or 
may not complete synchronously.
    +           ctx.pipeline().addBefore(ctx.name(), 
WebSocketServerProtocolHandler.class.getName(),
    +                   new WebSocketServerProtocolHandler(msg.path(), 
specification.getSubprotocol()));
    +
    +           // inject the message codec
    +           ctx.pipeline().addBefore(ctx.name(), 
messageCodec.getClass().getName(), messageCodec);
    +
    +           log.debug("Upgraded channel with WS protocol handler and 
message codec.");
    +
    +           // forward the message to the installed protocol handler to 
initiate handshaking
    +           HttpRequest request = msg.request();
    +           ReferenceCountUtil.retain(request);
    +           beforeCtx.fireChannelRead(request);
    +   }
    +
    +   @Override
    +   public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) 
throws Exception {
    +           log.error("WebSocket channel error; closing the channel.", 
cause);
    +           ctx.close();
    +   }
    +
    +   /**
    +    * Handles a client handshake request to open a WebSocket resource.  Returns a {@link CompletableFuture} to complete handshaking.
    +    *
    +    * <p>Implementations may decide whether to throw {@link RestHandlerException}s or fail the returned
    +    * {@link CompletableFuture} with a {@link RestHandlerException}.
    +    *
    +    * <p>Failing the future with another exception type or throwing any other exception is regarded as an
    +    * implementation error as it does not allow us to provide a meaningful HTTP status code. In this case a
    +    * {@link HttpResponseStatus#INTERNAL_SERVER_ERROR} will be returned.
    +    *
    +    * <p>The default implementation accepts every request by returning an already completed future.
    +    *
    +    * @param ctx the channel handler context
    +    * @param parameters the REST parameters resolved from the handshake request
    +    * @return future indicating completion of handshake pre-processing; must not be {@code null}.
    +    * @throws RestHandlerException to produce a pre-formatted HTTP error response.
    +    */
    +   protected CompletableFuture<Void> handshakeInitiated(ChannelHandlerContext ctx, M parameters) throws Exception {
    +           return CompletableFuture.completedFuture(null);
    +   }
    +
    +   /**
    +    * Invoked once the current channel has completed the WebSocket handshake; from then on the
    +    * channel carries WebSocket messages.
    +    *
    +    * <p>The default implementation does nothing.
    +    *
    +    * @param ctx the channel handler context
    +    * @param parameters the REST parameters resolved from the handshake request
    +    * @throws Exception if processing failed.
    +    */
    +   protected void handshakeCompleted(ChannelHandlerContext ctx, M parameters) throws Exception {
    +           // intentionally empty: subclasses override this to react to the established connection
    +   }
    +
    +   /**
    +    * Invoked when the current channel has received a WebSocket message.
    +    *
    +    * <p>Be sure to release the message object (default behavior).
    +    *
    +    * @param ctx the channel handler context
    +    * @param parameters the REST parameters
    +    * @param msg the message received
    +    * @throws Exception if the message could not be processed.
    +    */
    +   protected void messageReceived(ChannelHandlerContext ctx, M parameters, 
O msg) throws Exception {
    +           ReferenceCountUtil.release(msg);
    --- End diff --
    
    I think I will change the behavior to be more consistent with 
`SimpleChannelInboundHandler` and auto-release the message rather than forcing 
the subclass to release it.  Note that `channelRead0` is being renamed to 
`messageReceived` in Netty 5.


---

Reply via email to