This is an automated email from the ASF dual-hosted git repository.

jshao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-livy.git


The following commit(s) were added to refs/heads/master by this push:
     new 66b5833  [LIVY-735][RSC] Fix rpc channel closed when multi clients 
connect to one driver
66b5833 is described below

commit 66b5833e413bc10e39e3b92b585f496444c147d4
Author: runzhiwang <runzhiw...@tencent.com>
AuthorDate: Wed Jan 8 17:15:04 2020 +0800

    [LIVY-735][RSC] Fix rpc channel closed when multi clients connect to one 
driver
    
    ## What changes were proposed in this pull request?
    
    Currently, the driver tries to support communicating with multi-clients, by 
registering each client at 
https://github.com/apache/incubator-livy/blob/master/rsc/src/main/java/org/apache/livy/rsc/driver/RSCDriver.java#L220.
    
    But actually, if multiple clients connect to one driver, the RPC channel will 
be closed. The reasons are as follows.
    
    1.  In every communication, client sends two packages to driver: 
header{type, id}, and payload at 
https://github.com/apache/incubator-livy/blob/master/rsc/src/main/java/org/apache/livy/rsc/rpc/RpcDispatcher.java#L144.
    
    2. Suppose client1 sends header1, payload1 and client2 sends header2, payload2 
at the same time,
      and the driver receives the packages in the order: header1, header2, payload1, 
payload2.
    
    3. When driver receives header1, driver assigns lastHeader at 
https://github.com/apache/incubator-livy/blob/master/rsc/src/main/java/org/apache/livy/rsc/rpc/RpcDispatcher.java#L73.
    
    4. Then the driver receives header2 and processes it as a payload at 
https://github.com/apache/incubator-livy/blob/master/rsc/src/main/java/org/apache/livy/rsc/rpc/RpcDispatcher.java#L78
 which causes an exception and closes the RPC channel.
    
    In the multi-active HA mode (the design doc is at: 
https://docs.google.com/document/d/1bD3qYZpw14_NuCcSGUOfqQ0pqvSbCQsOLFuZp26Ohjc/edit?usp=sharing),
 sessions are allocated among servers by consistent hashing. If a new Livy 
server joins, some sessions will be migrated from the old Livy server to the new one. If the session 
client in the new Livy server connects to the driver before the session client in the old one is stopped, 
then both session clients will be connected to the driver, and the RPC channel will be closed.  
In this case, it's hard to e [...]
    
    How to fix:
    1. Move the code of processing client message from `RpcDispatcher` to each 
`Rpc`.
    2. Each `Rpc` registers itself to `channelRpc` in RpcDispatcher.
    3. `RpcDispatcher` dispatches each message to `Rpc` according to  
`ctx.channel()`.
    
    ## How was this patch tested?
    
    Existing UT and IT
    
    Author: runzhiwang <runzhiw...@tencent.com>
    
    Closes #268 from runzhiwang/multi-client-one-driver.
---
 .../java/org/apache/livy/rsc/driver/RSCDriver.java |   1 +
 rsc/src/main/java/org/apache/livy/rsc/rpc/Rpc.java | 185 ++++++++++++++++++++-
 .../org/apache/livy/rsc/rpc/RpcDispatcher.java     | 167 ++-----------------
 3 files changed, 196 insertions(+), 157 deletions(-)

diff --git a/rsc/src/main/java/org/apache/livy/rsc/driver/RSCDriver.java 
b/rsc/src/main/java/org/apache/livy/rsc/driver/RSCDriver.java
index 0d8eec5..a8f31f7 100644
--- a/rsc/src/main/java/org/apache/livy/rsc/driver/RSCDriver.java
+++ b/rsc/src/main/java/org/apache/livy/rsc/driver/RSCDriver.java
@@ -224,6 +224,7 @@ public class RSCDriver extends BaseProtocol {
       @Override
       public void onSuccess(Void unused) {
         clients.remove(client);
+        client.unRegisterRpc();
         if (!inShutdown.get()) {
           setupIdleTimeout();
         }
diff --git a/rsc/src/main/java/org/apache/livy/rsc/rpc/Rpc.java 
b/rsc/src/main/java/org/apache/livy/rsc/rpc/Rpc.java
index 868dc6d..5fce164 100644
--- a/rsc/src/main/java/org/apache/livy/rsc/rpc/Rpc.java
+++ b/rsc/src/main/java/org/apache/livy/rsc/rpc/Rpc.java
@@ -19,10 +19,11 @@ package org.apache.livy.rsc.rpc;
 
 import java.io.Closeable;
 import java.io.IOException;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.LinkedList;
-import java.util.Map;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.util.*;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicBoolean;
@@ -208,6 +209,7 @@ public class Rpc implements Closeable {
         dispatcher);
     Rpc rpc = new Rpc(new RSCConf(null), c, ImmediateEventExecutor.INSTANCE);
     rpc.dispatcher = dispatcher;
+    dispatcher.registerRpc(c, rpc);
     return rpc;
   }
 
@@ -218,6 +220,10 @@ public class Rpc implements Closeable {
   private final EventExecutorGroup egroup;
   private volatile RpcDispatcher dispatcher;
 
+  private final Map<Class<?>, Method> handlers = new ConcurrentHashMap<>();
+  private final Collection<OutstandingRpc> rpcCalls = new 
ConcurrentLinkedQueue<OutstandingRpc>();
+  private volatile Rpc.MessageHeader lastHeader;
+
   private Rpc(RSCConf config, Channel channel, EventExecutorGroup egroup) {
     Utils.checkArgument(channel != null);
     Utils.checkArgument(egroup != null);
@@ -239,6 +245,166 @@ public class Rpc implements Closeable {
   }
 
   /**
+   * For debugging purposes.
+   * @return The name of this Class.
+   */
+  protected String name() {
+    return getClass().getSimpleName();
+  }
+
+  public void handleMsg(ChannelHandlerContext ctx, Object msg, Class<?> 
handleClass, Object obj)
+      throws Exception {
+    if (lastHeader == null) {
+      if (!(msg instanceof MessageHeader)) {
+        LOG.warn("[{}] Expected RPC header, got {} instead.", name(),
+          msg != null ? msg.getClass().getName() : null);
+        throw new IllegalArgumentException();
+      }
+      lastHeader = (MessageHeader) msg;
+    } else {
+      LOG.debug("[{}] Received RPC message: type={} id={} payload={}", name(),
+        lastHeader.type, lastHeader.id, msg != null ? msg.getClass().getName() 
: null);
+      try {
+        switch (lastHeader.type) {
+          case CALL:
+            handleCall(ctx, msg, handleClass, obj);
+            break;
+          case REPLY:
+            handleReply(ctx, msg, findRpcCall(lastHeader.id));
+            break;
+          case ERROR:
+            handleError(ctx, msg, findRpcCall(lastHeader.id));
+            break;
+          default:
+            throw new IllegalArgumentException("Unknown RPC message type: " + 
lastHeader.type);
+        }
+      } finally {
+        lastHeader = null;
+      }
+    }
+  }
+
+  private void handleCall(ChannelHandlerContext ctx, Object msg, Class<?> 
handleClass, Object obj)
+      throws Exception {
+    Method handler = handlers.get(msg.getClass());
+    if (handler == null) {
+      // Try both getDeclaredMethod() and getMethod() so that we try both 
private methods
+      // of the class, and public methods of parent classes.
+      try {
+        handler = handleClass.getDeclaredMethod("handle", 
ChannelHandlerContext.class,
+            msg.getClass());
+      } catch (NoSuchMethodException e) {
+        try {
+          handler = handleClass.getMethod("handle", 
ChannelHandlerContext.class,
+              msg.getClass());
+        } catch (NoSuchMethodException e2) {
+          LOG.warn(String.format("[%s] Failed to find handler for msg '%s'.", 
name(),
+            msg.getClass().getName()));
+          writeMessage(MessageType.ERROR, 
Utils.stackTraceAsString(e.getCause()));
+          return;
+        }
+      }
+      handler.setAccessible(true);
+      handlers.put(msg.getClass(), handler);
+    }
+
+    try {
+      Object payload = handler.invoke(obj, ctx, msg);
+      if (payload == null) {
+        payload = new NullMessage();
+      }
+      writeMessage(MessageType.REPLY, payload);
+    } catch (InvocationTargetException ite) {
+      LOG.debug(String.format("[%s] Error in RPC handler.", name()), 
ite.getCause());
+      writeMessage(MessageType.ERROR, 
Utils.stackTraceAsString(ite.getCause()));
+    }
+  }
+
+  private void handleReply(ChannelHandlerContext ctx, Object msg, 
OutstandingRpc rpc) {
+    rpc.future.setSuccess(msg instanceof NullMessage ? null : msg);
+  }
+
+  private void handleError(ChannelHandlerContext ctx, Object msg, 
OutstandingRpc rpc) {
+    if (msg instanceof String) {
+      LOG.warn("Received error message:{}.", msg);
+      rpc.future.setFailure(new RpcException((String) msg));
+    } else {
+      String error = String.format("Received error with unexpected payload 
(%s).",
+          msg != null ? msg.getClass().getName() : null);
+      LOG.warn(String.format("[%s] %s", name(), error));
+      rpc.future.setFailure(new IllegalArgumentException(error));
+      ctx.close();
+    }
+  }
+
+  private void writeMessage(MessageType replyType, Object payload) {
+    channel.write(new MessageHeader(lastHeader.id, replyType));
+    channel.writeAndFlush(payload);
+  }
+
+  private OutstandingRpc findRpcCall(long id) {
+    for (Iterator<OutstandingRpc> it = rpcCalls.iterator(); it.hasNext();) {
+      OutstandingRpc rpc = it.next();
+      if (rpc.id == id) {
+        it.remove();
+        return rpc;
+      }
+    }
+    throw new IllegalArgumentException(String.format(
+        "Received RPC reply for unknown RPC (%d).", id));
+  }
+
+  private void registerRpcCall(long id, Promise<?> promise, String type) {
+    LOG.debug("[{}] Registered outstanding rpc {} ({}).", name(), id, type);
+    rpcCalls.add(new OutstandingRpc(id, promise));
+  }
+
+  private void discardRpcCall(long id) {
+    LOG.debug("[{}] Discarding failed RPC {}.", name(), id);
+    findRpcCall(id);
+  }
+
+  private static class OutstandingRpc {
+    final long id;
+    final Promise<Object> future;
+
+    @SuppressWarnings("unchecked")
+    OutstandingRpc(long id, Promise<?> future) {
+      this.id = id;
+      this.future = (Promise<Object>) future;
+    }
+  }
+
+  public void handleChannelException(ChannelHandlerContext ctx, Throwable 
cause) {
+    if (LOG.isDebugEnabled()) {
+      LOG.debug(String.format("[%s] Caught exception in channel pipeline.", 
name()), cause);
+    } else {
+      LOG.info(String.format("[%s] Caught exception in channel pipeline.", 
name()), cause);
+    }
+
+    if (lastHeader != null) {
+      // There's an RPC waiting for a reply. Exception was most probably 
caught while processing
+      // the RPC, so send an error.
+      channel.write(new MessageHeader(lastHeader.id, MessageType.ERROR));
+      channel.writeAndFlush(Utils.stackTraceAsString(cause));
+      lastHeader = null;
+    }
+
+    ctx.close();
+  }
+
+  public void handleChannelInactive() {
+    if (rpcCalls.size() > 0) {
+      LOG.warn("[{}] Closing RPC channel with {} outstanding RPCs.", name(), 
rpcCalls.size());
+      for (OutstandingRpc rpc : rpcCalls) {
+        rpc.future.cancel(true);
+      }
+    } else {
+      LOG.debug("Channel {} became inactive.", channel);
+    }
+  }
+
+  /**
    * Send an RPC call to the remote endpoint and returns a future that can be 
used to monitor the
    * operation.
    *
@@ -269,13 +435,13 @@ public class Rpc implements Closeable {
             if (!cf.isSuccess() && !promise.isDone()) {
               LOG.warn("Failed to send RPC, closing connection.", cf.cause());
               promise.setFailure(cf.cause());
-              dispatcher.discardRpc(id);
+              discardRpcCall(id);
               close();
             }
           }
       };
 
-      dispatcher.registerRpc(id, promise, msg.getClass().getName());
+      registerRpcCall(id, promise, msg.getClass().getName());
       channel.eventLoop().submit(new Runnable() {
         @Override
         public void run() {
@@ -294,11 +460,18 @@ public class Rpc implements Closeable {
     return channel;
   }
 
+  public void unRegisterRpc() {
+    if (dispatcher != null) {
+      dispatcher.unregisterRpc(channel);
+    }
+  }
+
   void setDispatcher(RpcDispatcher dispatcher) {
     Utils.checkNotNull(dispatcher);
     Utils.checkState(this.dispatcher == null, "Dispatcher already set.");
     this.dispatcher = dispatcher;
     channel.pipeline().addLast("dispatcher", dispatcher);
+    dispatcher.registerRpc(channel, this);
   }
 
   @Override
diff --git a/rsc/src/main/java/org/apache/livy/rsc/rpc/RpcDispatcher.java 
b/rsc/src/main/java/org/apache/livy/rsc/rpc/RpcDispatcher.java
index 0c149b0..88744c2 100644
--- a/rsc/src/main/java/org/apache/livy/rsc/rpc/RpcDispatcher.java
+++ b/rsc/src/main/java/org/apache/livy/rsc/rpc/RpcDispatcher.java
@@ -17,22 +17,15 @@
 
 package org.apache.livy.rsc.rpc;
 
-import java.lang.reflect.InvocationTargetException;
-import java.lang.reflect.Method;
-import java.util.Collection;
-import java.util.Iterator;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.ConcurrentLinkedQueue;
 
+import io.netty.channel.Channel;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.channel.SimpleChannelInboundHandler;
-import io.netty.util.concurrent.Promise;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import org.apache.livy.rsc.Utils;
-
 /**
  * An implementation of ChannelInboundHandler that dispatches incoming 
messages to an instance
  * method based on the method signature.
@@ -49,10 +42,7 @@ public abstract class RpcDispatcher extends 
SimpleChannelInboundHandler<Object>
 
   private static final Logger LOG = 
LoggerFactory.getLogger(RpcDispatcher.class);
 
-  private final Map<Class<?>, Method> handlers = new ConcurrentHashMap<>();
-  private final Collection<OutstandingRpc> rpcs = new 
ConcurrentLinkedQueue<OutstandingRpc>();
-
-  private volatile Rpc.MessageHeader lastHeader;
+  private final Map<Channel, Rpc> channelRpc = new ConcurrentHashMap<>();
 
   /**
    * Override this to add a name to the dispatcher, for debugging purposes.
@@ -62,161 +52,36 @@ public abstract class RpcDispatcher extends 
SimpleChannelInboundHandler<Object>
     return getClass().getSimpleName();
   }
 
-  @Override
-  protected final void channelRead0(ChannelHandlerContext ctx, Object msg) 
throws Exception {
-    if (lastHeader == null) {
-      if (!(msg instanceof Rpc.MessageHeader)) {
-        LOG.warn("[{}] Expected RPC header, got {} instead.", name(),
-            msg != null ? msg.getClass().getName() : null);
-        throw new IllegalArgumentException();
-      }
-      lastHeader = (Rpc.MessageHeader) msg;
-    } else {
-      LOG.debug("[{}] Received RPC message: type={} id={} payload={}", name(),
-        lastHeader.type, lastHeader.id, msg != null ? msg.getClass().getName() 
: null);
-      try {
-        switch (lastHeader.type) {
-        case CALL:
-          handleCall(ctx, msg);
-          break;
-        case REPLY:
-          handleReply(ctx, msg, findRpc(lastHeader.id));
-          break;
-        case ERROR:
-          handleError(ctx, msg, findRpc(lastHeader.id));
-          break;
-        default:
-          throw new IllegalArgumentException("Unknown RPC message type: " + 
lastHeader.type);
-        }
-      } finally {
-        lastHeader = null;
-      }
-    }
+  public void registerRpc(Channel channel, Rpc rpc) {
+    channelRpc.put(channel, rpc);
   }
 
-  private OutstandingRpc findRpc(long id) {
-    for (Iterator<OutstandingRpc> it = rpcs.iterator(); it.hasNext();) {
-      OutstandingRpc rpc = it.next();
-      if (rpc.id == id) {
-        it.remove();
-        return rpc;
-      }
-    }
-    throw new IllegalArgumentException(String.format(
-        "Received RPC reply for unknown RPC (%d).", id));
+  public void unregisterRpc(Channel channel) {
+    channelRpc.remove(channel);
   }
 
-  private void handleCall(ChannelHandlerContext ctx, Object msg) throws 
Exception {
-    Method handler = handlers.get(msg.getClass());
-    if (handler == null) {
-      // Try both getDeclaredMethod() and getMethod() so that we try both 
private methods
-      // of the class, and public methods of parent classes.
-      try {
-        handler = getClass().getDeclaredMethod("handle", 
ChannelHandlerContext.class,
-            msg.getClass());
-      } catch (NoSuchMethodException e) {
-        try {
-          handler = getClass().getMethod("handle", ChannelHandlerContext.class,
-              msg.getClass());
-        } catch (NoSuchMethodException e2) {
-          LOG.warn(String.format("[%s] Failed to find handler for msg '%s'.", 
name(),
-            msg.getClass().getName()));
-          writeMessage(ctx, Rpc.MessageType.ERROR, 
Utils.stackTraceAsString(e.getCause()));
-          return;
-        }
-      }
-      handler.setAccessible(true);
-      handlers.put(msg.getClass(), handler);
-    }
-
-    try {
-      Object payload = handler.invoke(this, ctx, msg);
-      if (payload == null) {
-        payload = new Rpc.NullMessage();
-      }
-      writeMessage(ctx, Rpc.MessageType.REPLY, payload);
-    } catch (InvocationTargetException ite) {
-      LOG.debug(String.format("[%s] Error in RPC handler.", name()), 
ite.getCause());
-      writeMessage(ctx, Rpc.MessageType.ERROR, 
Utils.stackTraceAsString(ite.getCause()));
+  private Rpc getRpc(ChannelHandlerContext ctx) {
+    Channel channel = ctx.channel();
+    if (!channelRpc.containsKey(channel)) {
+      throw new IllegalArgumentException("not existed channel:" + channel);
     }
-  }
-
-  private void writeMessage(ChannelHandlerContext ctx, Rpc.MessageType 
replyType, Object payload) {
-    ctx.channel().write(new Rpc.MessageHeader(lastHeader.id, replyType));
-    ctx.channel().writeAndFlush(payload);
-  }
 
-  private void handleReply(ChannelHandlerContext ctx, Object msg, 
OutstandingRpc rpc)
-      throws Exception {
-    rpc.future.setSuccess(msg instanceof Rpc.NullMessage ? null : msg);
+    return channelRpc.get(channel);
   }
 
-  private void handleError(ChannelHandlerContext ctx, Object msg, 
OutstandingRpc rpc)
-      throws Exception {
-    if (msg instanceof String) {
-      LOG.warn("Received error message:{}.", msg);
-      rpc.future.setFailure(new RpcException((String) msg));
-    } else {
-      String error = String.format("Received error with unexpected payload 
(%s).",
-          msg != null ? msg.getClass().getName() : null);
-      LOG.warn(String.format("[%s] %s", name(), error));
-      rpc.future.setFailure(new IllegalArgumentException(error));
-      ctx.close();
-    }
+  @Override
+  protected final void channelRead0(ChannelHandlerContext ctx, Object msg) 
throws Exception {
+    getRpc(ctx).handleMsg(ctx, msg, getClass(), this);
   }
 
   @Override
   public final void exceptionCaught(ChannelHandlerContext ctx, Throwable 
cause) {
-    if (LOG.isDebugEnabled()) {
-      LOG.debug(String.format("[%s] Caught exception in channel pipeline.", 
name()), cause);
-    } else {
-      LOG.info("[{}] Closing channel due to exception in pipeline ({}).", 
name(),
-          cause.getMessage());
-    }
-
-    if (lastHeader != null) {
-      // There's an RPC waiting for a reply. Exception was most probably 
caught while processing
-      // the RPC, so send an error.
-      ctx.channel().write(new Rpc.MessageHeader(lastHeader.id, 
Rpc.MessageType.ERROR));
-      ctx.channel().writeAndFlush(Utils.stackTraceAsString(cause));
-      lastHeader = null;
-    }
-
-    ctx.close();
+    getRpc(ctx).handleChannelException(ctx, cause);
   }
 
   @Override
   public final void channelInactive(ChannelHandlerContext ctx) throws 
Exception {
-    if (rpcs.size() > 0) {
-      LOG.warn("[{}] Closing RPC channel with {} outstanding RPCs.", name(), 
rpcs.size());
-      for (OutstandingRpc rpc : rpcs) {
-        rpc.future.cancel(true);
-      }
-    } else {
-      LOG.debug("Channel {} became inactive.", ctx.channel());
-    }
+    getRpc(ctx).handleChannelInactive();
     super.channelInactive(ctx);
   }
-
-  void registerRpc(long id, Promise<?> promise, String type) {
-    LOG.debug("[{}] Registered outstanding rpc {} ({}).", name(), id, type);
-    rpcs.add(new OutstandingRpc(id, promise));
-  }
-
-  void discardRpc(long id) {
-    LOG.debug("[{}] Discarding failed RPC {}.", name(), id);
-    findRpc(id);
-  }
-
-  private static class OutstandingRpc {
-    final long id;
-    final Promise<Object> future;
-
-    @SuppressWarnings("unchecked")
-    OutstandingRpc(long id, Promise<?> future) {
-      this.id = id;
-      this.future = (Promise<Object>) future;
-    }
-  }
-
 }

Reply via email to