[FLINK-4383] [rpc] Eagerly serialize remote rpc invocation messages

This PR introduces eager serialization of remote rpc invocation messages.
That way it is possible to check whether the message is serializable and
whether it exceeds the maximum allowed Akka frame size. If either of these
constraints is violated, a proper exception is thrown instead of the error
being silently swallowed, as Akka would do.

Address PR comments

This closes #2365.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/433a1fd0
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/433a1fd0
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/433a1fd0

Branch: refs/heads/flip-6
Commit: 433a1fd0364cff2c73d81629dcb470743dea84ae
Parents: 4ca049b
Author: Till Rohrmann <trohrm...@apache.org>
Authored: Fri Aug 12 10:32:30 2016 +0200
Committer: Till Rohrmann <trohrm...@apache.org>
Committed: Wed Sep 21 11:39:13 2016 +0200

----------------------------------------------------------------------
 .../flink/runtime/rpc/akka/AkkaGateway.java     |   2 +-
 .../runtime/rpc/akka/AkkaInvocationHandler.java |  83 ++++++--
 .../flink/runtime/rpc/akka/AkkaRpcActor.java    |  26 ++-
 .../flink/runtime/rpc/akka/AkkaRpcService.java  |  20 +-
 .../rpc/akka/messages/LocalRpcInvocation.java   |  54 +++++
 .../rpc/akka/messages/RemoteRpcInvocation.java  | 206 ++++++++++++++++++
 .../rpc/akka/messages/RpcInvocation.java        | 106 +++-------
 .../runtime/rpc/akka/AkkaRpcServiceTest.java    |   2 +-
 .../rpc/akka/MessageSerializationTest.java      | 210 +++++++++++++++++++
 9 files changed, 597 insertions(+), 112 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/433a1fd0/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaGateway.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaGateway.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaGateway.java
index ec3091c..f6125dc 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaGateway.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaGateway.java
@@ -26,5 +26,5 @@ import org.apache.flink.runtime.rpc.RpcGateway;
  */
 interface AkkaGateway extends RpcGateway {
 
-       ActorRef getRpcServer();
+       ActorRef getRpcEndpoint();
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/433a1fd0/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaInvocationHandler.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaInvocationHandler.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaInvocationHandler.java
index 580b161..297104b 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaInvocationHandler.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaInvocationHandler.java
@@ -25,13 +25,17 @@ import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.runtime.rpc.MainThreadExecutor;
 import org.apache.flink.runtime.rpc.RpcTimeout;
 import org.apache.flink.runtime.rpc.akka.messages.CallAsync;
+import org.apache.flink.runtime.rpc.akka.messages.LocalRpcInvocation;
+import org.apache.flink.runtime.rpc.akka.messages.RemoteRpcInvocation;
 import org.apache.flink.runtime.rpc.akka.messages.RpcInvocation;
 import org.apache.flink.runtime.rpc.akka.messages.RunAsync;
 import org.apache.flink.util.Preconditions;
+import org.apache.log4j.Logger;
 import scala.concurrent.Await;
 import scala.concurrent.Future;
 import scala.concurrent.duration.FiniteDuration;
 
+import java.io.IOException;
 import java.lang.annotation.Annotation;
 import java.lang.reflect.InvocationHandler;
 import java.lang.reflect.Method;
@@ -42,19 +46,28 @@ import static 
org.apache.flink.util.Preconditions.checkNotNull;
 import static org.apache.flink.util.Preconditions.checkArgument;
 
 /**
- * Invocation handler to be used with a {@link AkkaRpcActor}. The invocation 
handler wraps the
- * rpc in a {@link RpcInvocation} message and then sends it to the {@link 
AkkaRpcActor} where it is
+ * Invocation handler to be used with an {@link AkkaRpcActor}. The invocation 
handler wraps the
+ * rpc in a {@link LocalRpcInvocation} message and then sends it to the {@link 
AkkaRpcActor} where it is
  * executed.
  */
 class AkkaInvocationHandler implements InvocationHandler, AkkaGateway, 
MainThreadExecutor {
-       private final ActorRef rpcServer;
+       private static final Logger LOG = 
Logger.getLogger(AkkaInvocationHandler.class);
+
+       private final ActorRef rpcEndpoint;
+
+       // whether the actor ref is local and thus no message serialization is 
needed
+       private final boolean isLocal;
 
        // default timeout for asks
        private final Timeout timeout;
 
-       AkkaInvocationHandler(ActorRef rpcServer, Timeout timeout) {
-               this.rpcServer = Preconditions.checkNotNull(rpcServer);
+       private final long maximumFramesize;
+
+       AkkaInvocationHandler(ActorRef rpcEndpoint, Timeout timeout, long 
maximumFramesize) {
+               this.rpcEndpoint = Preconditions.checkNotNull(rpcEndpoint);
+               this.isLocal = 
this.rpcEndpoint.path().address().hasLocalScope();
                this.timeout = Preconditions.checkNotNull(timeout);
+               this.maximumFramesize = maximumFramesize;
        }
 
        @Override
@@ -76,23 +89,43 @@ class AkkaInvocationHandler implements InvocationHandler, 
AkkaGateway, MainThrea
                                parameterAnnotations,
                                args);
 
-                       RpcInvocation rpcInvocation = new RpcInvocation(
-                               methodName,
-                               filteredArguments.f0,
-                               filteredArguments.f1);
+                       RpcInvocation rpcInvocation;
+
+                       if (isLocal) {
+                               rpcInvocation = new LocalRpcInvocation(
+                                       methodName,
+                                       filteredArguments.f0,
+                                       filteredArguments.f1);
+                       } else {
+                               try {
+                                       RemoteRpcInvocation remoteRpcInvocation 
= new RemoteRpcInvocation(
+                                               methodName,
+                                               filteredArguments.f0,
+                                               filteredArguments.f1);
+
+                                       if (remoteRpcInvocation.getSize() > 
maximumFramesize) {
+                                               throw new IOException("The rpc 
invocation size exceeds the maximum akka framesize.");
+                                       } else {
+                                               rpcInvocation = 
remoteRpcInvocation;
+                                       }
+                               } catch (IOException e) {
+                                       LOG.warn("Could not create remote rpc 
invocation message. Failing rpc invocation because...", e);
+                                       throw e;
+                               }
+                       }
 
                        Class<?> returnType = method.getReturnType();
 
                        if (returnType.equals(Void.TYPE)) {
-                               rpcServer.tell(rpcInvocation, 
ActorRef.noSender());
+                               rpcEndpoint.tell(rpcInvocation, 
ActorRef.noSender());
 
                                result = null;
                        } else if (returnType.equals(Future.class)) {
                                // execute an asynchronous call
-                               result = Patterns.ask(rpcServer, rpcInvocation, 
futureTimeout);
+                               result = Patterns.ask(rpcEndpoint, 
rpcInvocation, futureTimeout);
                        } else {
                                // execute a synchronous call
-                               Future<?> futureResult = 
Patterns.ask(rpcServer, rpcInvocation, futureTimeout);
+                               Future<?> futureResult = 
Patterns.ask(rpcEndpoint, rpcInvocation, futureTimeout);
                                FiniteDuration duration = timeout.duration();
 
                                result = Await.result(futureResult, duration);
@@ -103,8 +136,8 @@ class AkkaInvocationHandler implements InvocationHandler, 
AkkaGateway, MainThrea
        }
 
        @Override
-       public ActorRef getRpcServer() {
-               return rpcServer;
+       public ActorRef getRpcEndpoint() {
+               return rpcEndpoint;
        }
 
        @Override
@@ -117,19 +150,25 @@ class AkkaInvocationHandler implements InvocationHandler, 
AkkaGateway, MainThrea
                checkNotNull(runnable, "runnable");
                checkArgument(delay >= 0, "delay must be zero or greater");
                
-               // Unfortunately I couldn't find a way to allow only local 
communication. Therefore, the
-               // runnable field is transient transient
-               rpcServer.tell(new RunAsync(runnable, delay), 
ActorRef.noSender());
+               if (isLocal) {
+                       rpcEndpoint.tell(new RunAsync(runnable, delay), 
ActorRef.noSender());
+               } else {
+                       throw new RuntimeException("Trying to send a Runnable 
to a remote actor at " +
+                               rpcEndpoint.path() + ". This is not 
supported.");
+               }
        }
 
        @Override
        public <V> Future<V> callAsync(Callable<V> callable, Timeout 
callTimeout) {
-               // Unfortunately I couldn't find a way to allow only local 
communication. Therefore, the
-               // callable field is declared transient
-               @SuppressWarnings("unchecked")
-               Future<V> result = (Future<V>) Patterns.ask(rpcServer, new 
CallAsync(callable), callTimeout);
+               if(isLocal) {
+                       @SuppressWarnings("unchecked")
+                       Future<V> result = (Future<V>) 
Patterns.ask(rpcEndpoint, new CallAsync(callable), callTimeout);
 
-               return result;
+                       return result;
+               } else {
+                       throw new RuntimeException("Trying to send a Callable 
to a remote actor at " +
+                               rpcEndpoint.path() + ". This is not 
supported.");
+               }
        }
 
        /**

http://git-wip-us.apache.org/repos/asf/flink/blob/433a1fd0/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcActor.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcActor.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcActor.java
index 5e0a7da..dfcbcc3 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcActor.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcActor.java
@@ -26,6 +26,7 @@ import org.apache.flink.runtime.rpc.MainThreadValidatorUtil;
 import org.apache.flink.runtime.rpc.RpcEndpoint;
 import org.apache.flink.runtime.rpc.RpcGateway;
 import org.apache.flink.runtime.rpc.akka.messages.CallAsync;
+import org.apache.flink.runtime.rpc.akka.messages.LocalRpcInvocation;
 import org.apache.flink.runtime.rpc.akka.messages.RpcInvocation;
 import org.apache.flink.runtime.rpc.akka.messages.RunAsync;
 
@@ -35,6 +36,7 @@ import org.slf4j.LoggerFactory;
 import scala.concurrent.Future;
 import scala.concurrent.duration.FiniteDuration;
 
+import java.io.IOException;
 import java.lang.reflect.Method;
 import java.util.concurrent.Callable;
 import java.util.concurrent.TimeUnit;
@@ -42,10 +44,10 @@ import java.util.concurrent.TimeUnit;
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
 /**
- * Akka rpc actor which receives {@link RpcInvocation}, {@link RunAsync} and 
{@link CallAsync}
+ * Akka rpc actor which receives {@link LocalRpcInvocation}, {@link RunAsync} 
and {@link CallAsync}
  * messages.
  * <p>
- * The {@link RpcInvocation} designates a rpc and is dispatched to the given 
{@link RpcEndpoint}
+ * The {@link LocalRpcInvocation} designates a rpc and is dispatched to the 
given {@link RpcEndpoint}
  * instance.
  * <p>
  * The {@link RunAsync} and {@link CallAsync} messages contain executable code 
which is executed
@@ -95,15 +97,12 @@ class AkkaRpcActor<C extends RpcGateway, T extends 
RpcEndpoint<C>> extends Untyp
         * @param rpcInvocation Rpc invocation message
         */
        private void handleRpcInvocation(RpcInvocation rpcInvocation) {
-               Method rpcMethod = null;
-
                try {
-                       rpcMethod = 
lookupRpcMethod(rpcInvocation.getMethodName(), 
rpcInvocation.getParameterTypes());
-               } catch (final NoSuchMethodException e) {
-                       LOG.error("Could not find rpc method for rpc 
invocation: {}.", rpcInvocation, e);
-               }
+                       String methodName = rpcInvocation.getMethodName();
+                       Class<?>[] parameterTypes = 
rpcInvocation.getParameterTypes();
+
+                       Method rpcMethod = lookupRpcMethod(methodName, 
parameterTypes);
 
-               if (rpcMethod != null) {
                        if (rpcMethod.getReturnType().equals(Void.TYPE)) {
                                // No return value to send back
                                try {
@@ -127,6 +126,12 @@ class AkkaRpcActor<C extends RpcGateway, T extends 
RpcEndpoint<C>> extends Untyp
                                        getSender().tell(new Status.Failure(e), 
getSelf());
                                }
                        }
+               } catch(ClassNotFoundException e) {
+                       LOG.error("Could not load method arguments.", e);
+               } catch (IOException e) {
+                       LOG.error("Could not deserialize rpc invocation 
message.", e);
+               } catch (final NoSuchMethodException e) {
+                       LOG.error("Could not find rpc method for rpc 
invocation: {}.", rpcInvocation, e);
                }
        }
 
@@ -195,7 +200,8 @@ class AkkaRpcActor<C extends RpcGateway, T extends 
RpcEndpoint<C>> extends Untyp
         * @param methodName Name of the method
         * @param parameterTypes Parameter types of the method
         * @return Method of the rpc endpoint
-        * @throws NoSuchMethodException
+        * @throws NoSuchMethodException Thrown if the method with the given 
name and parameter types
+        *                                                                      
cannot be found at the rpc endpoint
         */
        private Method lookupRpcMethod(final String methodName, final 
Class<?>[] parameterTypes) throws NoSuchMethodException {
                return rpcEndpoint.getClass().getMethod(methodName, 
parameterTypes);

http://git-wip-us.apache.org/repos/asf/flink/blob/433a1fd0/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcService.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcService.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcService.java
index db40f10..b963c53 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcService.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcService.java
@@ -58,17 +58,27 @@ public class AkkaRpcService implements RpcService {
 
        private static final Logger LOG = 
LoggerFactory.getLogger(AkkaRpcService.class);
 
+       static final String MAXIMUM_FRAME_SIZE_PATH = 
"akka.remote.netty.tcp.maximum-frame-size";
+
        private final Object lock = new Object();
 
        private final ActorSystem actorSystem;
        private final Timeout timeout;
        private final Set<ActorRef> actors = new HashSet<>(4);
+       private final long maximumFramesize;
 
        private volatile boolean stopped;
 
        public AkkaRpcService(final ActorSystem actorSystem, final Timeout 
timeout) {
                this.actorSystem = checkNotNull(actorSystem, "actor system");
                this.timeout = checkNotNull(timeout, "timeout");
+
+               if 
(actorSystem.settings().config().hasPath(MAXIMUM_FRAME_SIZE_PATH)) {
+                       maximumFramesize = 
actorSystem.settings().config().getBytes(MAXIMUM_FRAME_SIZE_PATH);
+               } else {
+                       // only local communication
+                       maximumFramesize = Long.MAX_VALUE;
+               }
        }
 
        // this method does not mutate state and is thus thread-safe
@@ -88,7 +98,7 @@ public class AkkaRpcService implements RpcService {
                        public C apply(Object obj) {
                                ActorRef actorRef = ((ActorIdentity) 
obj).getRef();
 
-                               InvocationHandler akkaInvocationHandler = new 
AkkaInvocationHandler(actorRef, timeout);
+                               InvocationHandler akkaInvocationHandler = new 
AkkaInvocationHandler(actorRef, timeout, maximumFramesize);
 
                                @SuppressWarnings("unchecked")
                                C proxy = (C) Proxy.newProxyInstance(
@@ -116,7 +126,7 @@ public class AkkaRpcService implements RpcService {
 
                LOG.info("Starting RPC endpoint for {} at {} .", 
rpcEndpoint.getClass().getName(), actorRef.path());
 
-               InvocationHandler akkaInvocationHandler = new 
AkkaInvocationHandler(actorRef, timeout);
+               InvocationHandler akkaInvocationHandler = new 
AkkaInvocationHandler(actorRef, timeout, maximumFramesize);
 
                // Rather than using the System ClassLoader directly, we derive 
the ClassLoader
                // from this class . That works better in cases where Flink 
runs embedded and all Flink
@@ -142,12 +152,12 @@ public class AkkaRpcService implements RpcService {
                                if (stopped) {
                                        return;
                                } else {
-                                       fromThisService = 
actors.remove(akkaClient.getRpcServer());
+                                       fromThisService = 
actors.remove(akkaClient.getRpcEndpoint());
                                }
                        }
 
                        if (fromThisService) {
-                               ActorRef selfActorRef = 
akkaClient.getRpcServer();
+                               ActorRef selfActorRef = 
akkaClient.getRpcEndpoint();
                                LOG.info("Stopping RPC endpoint {}.", 
selfActorRef.path());
                                selfActorRef.tell(PoisonPill.getInstance(), 
ActorRef.noSender());
                        } else {
@@ -178,7 +188,7 @@ public class AkkaRpcService implements RpcService {
                checkState(!stopped, "RpcService is stopped");
 
                if (selfGateway instanceof AkkaGateway) {
-                       ActorRef actorRef = ((AkkaGateway) 
selfGateway).getRpcServer();
+                       ActorRef actorRef = ((AkkaGateway) 
selfGateway).getRpcEndpoint();
                        return AkkaUtils.getAkkaURL(actorSystem, actorRef);
                } else {
                        String className = AkkaGateway.class.getName();

http://git-wip-us.apache.org/repos/asf/flink/blob/433a1fd0/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/LocalRpcInvocation.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/LocalRpcInvocation.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/LocalRpcInvocation.java
new file mode 100644
index 0000000..97c10d7
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/LocalRpcInvocation.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.rpc.akka.messages;
+
+import org.apache.flink.util.Preconditions;
+
+/**
+ * Local rpc invocation message containing the remote procedure name, its parameter types and the
+ * corresponding call arguments. This message will only be sent if the communication is local and,
+ * thus, the message does not have to be serialized.
+ */
+public final class LocalRpcInvocation implements RpcInvocation {
+
+       // name of the rpc method to invoke
+       private final String methodName;
+       // parameter types used to look up the rpc method
+       private final Class<?>[] parameterTypes;
+       // call arguments; not checked by the constructor and may therefore be null
+       private final Object[] args;
+
+       /**
+        * Creates a new local rpc invocation message.
+        *
+        * @param methodName Name of the rpc method, must not be null
+        * @param parameterTypes Parameter types of the rpc method, must not be null
+        * @param args Arguments of the rpc call, may be null
+        */
+       public LocalRpcInvocation(String methodName, Class<?>[] parameterTypes, Object[] args) {
+               this.methodName = Preconditions.checkNotNull(methodName);
+               this.parameterTypes = Preconditions.checkNotNull(parameterTypes);
+               this.args = args;
+       }
+
+       @Override
+       public String getMethodName() {
+               return methodName;
+       }
+
+       @Override
+       public Class<?>[] getParameterTypes() {
+               return parameterTypes;
+       }
+
+       @Override
+       public Object[] getArgs() {
+               return args;
+       }
+
+       /**
+        * Returns a human readable description of the invocation, used when the message is logged.
+        * The arguments are deliberately left out because their string form may be large.
+        */
+       @Override
+       public String toString() {
+               return "LocalRpcInvocation{methodName=" + methodName +
+                       ", parameterTypes=" + java.util.Arrays.toString(parameterTypes) + '}';
+       }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/433a1fd0/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/RemoteRpcInvocation.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/RemoteRpcInvocation.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/RemoteRpcInvocation.java
new file mode 100644
index 0000000..bc26a29
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/RemoteRpcInvocation.java
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.rpc.akka.messages;
+
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.SerializedValue;
+
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.io.Serializable;
+
+/**
+ * Remote rpc invocation message which is used when the actor communication is remote and, thus, the
+ * message has to be serialized.
+ * <p>
+ * In order to fail fast and report an appropriate error message to the user, the method name, the
+ * parameter types and the arguments are eagerly serialized. In case the invocation call
+ * contains a non-serializable object, then an {@link IOException} is thrown.
+ */
+public class RemoteRpcInvocation implements RpcInvocation, Serializable {
+       private static final long serialVersionUID = 6179354390913843809L;
+
+       // Serialized invocation data
+       private SerializedValue<RemoteRpcInvocation.MethodInvocation> serializedMethodInvocation;
+
+       // Transient field which is lazily initialized upon first access to the invocation data
+       private transient RemoteRpcInvocation.MethodInvocation methodInvocation;
+
+       /**
+        * Creates a new remote rpc invocation and eagerly serializes the invocation data.
+        *
+        * @param methodName Name of the rpc method, must not be null
+        * @param parameterTypes Parameter types of the rpc method, must not be null
+        * @param args Arguments of the rpc call, may be null
+        * @throws IOException If the invocation data (e.g. one of the arguments) could not be serialized
+        */
+       public RemoteRpcInvocation(
+               final String methodName,
+               final Class<?>[] parameterTypes,
+               final Object[] args) throws IOException {
+
+               serializedMethodInvocation = new SerializedValue<>(new RemoteRpcInvocation.MethodInvocation(methodName, parameterTypes, args));
+               methodInvocation = null;
+       }
+
+       @Override
+       public String getMethodName() throws IOException, ClassNotFoundException {
+               deserializeMethodInvocation();
+
+               return methodInvocation.getMethodName();
+       }
+
+       @Override
+       public Class<?>[] getParameterTypes() throws IOException, ClassNotFoundException {
+               deserializeMethodInvocation();
+
+               return methodInvocation.getParameterTypes();
+       }
+
+       @Override
+       public Object[] getArgs() throws IOException, ClassNotFoundException {
+               deserializeMethodInvocation();
+
+               return methodInvocation.getArgs();
+       }
+
+       /**
+        * Size (#bytes of the serialized data) of the rpc invocation message.
+        *
+        * @return Size of the remote rpc invocation message
+        */
+       public long getSize() {
+               return serializedMethodInvocation.getByteArray().length;
+       }
+
+       /**
+        * Deserializes the invocation data on first access and caches it in the transient
+        * {@link #methodInvocation} field. The system class loader is used for deserialization.
+        */
+       private void deserializeMethodInvocation() throws IOException, ClassNotFoundException {
+               if (methodInvocation == null) {
+                       methodInvocation = serializedMethodInvocation.deserializeValue(ClassLoader.getSystemClassLoader());
+               }
+       }
+
+       /**
+        * Returns a human readable description of the invocation, used when the message is logged.
+        * If the invocation data cannot be deserialized, only its serialized size is reported.
+        */
+       @Override
+       public String toString() {
+               try {
+                       deserializeMethodInvocation();
+
+                       return "RemoteRpcInvocation{methodName=" + methodInvocation.getMethodName() +
+                               ", parameterTypes=" + java.util.Arrays.toString(methodInvocation.getParameterTypes()) + '}';
+               } catch (IOException | ClassNotFoundException e) {
+                       return "RemoteRpcInvocation{" + getSize() + " serialized bytes, not deserializable}";
+               }
+       }
+
+       // -------------------------------------------------------------------
+       // Serialization methods
+       // -------------------------------------------------------------------
+
+       // Only the already serialized invocation data is shipped; the cached deserialized form is transient
+       private void writeObject(ObjectOutputStream oos) throws IOException {
+               oos.writeObject(serializedMethodInvocation);
+       }
+
+       @SuppressWarnings("unchecked")
+       private void readObject(ObjectInputStream ois) throws IOException, ClassNotFoundException {
+               serializedMethodInvocation = (SerializedValue<RemoteRpcInvocation.MethodInvocation>) ois.readObject();
+               // the deserialized invocation data is lazily recreated upon first access
+               methodInvocation = null;
+       }
+
+       // -------------------------------------------------------------------
+       // Utility classes
+       // -------------------------------------------------------------------
+
+       /**
+        * Wrapper class for the method invocation information
+        */
+       private static final class MethodInvocation implements Serializable {
+               private static final long serialVersionUID = 9187962608946082519L;
+
+               private String methodName;
+               private Class<?>[] parameterTypes;
+               private Object[] args;
+
+               private MethodInvocation(final String methodName, final Class<?>[] parameterTypes, final Object[] args) {
+                       this.methodName = Preconditions.checkNotNull(methodName);
+                       this.parameterTypes = Preconditions.checkNotNull(parameterTypes);
+                       this.args = args;
+               }
+
+               String getMethodName() {
+                       return methodName;
+               }
+
+               Class<?>[] getParameterTypes() {
+                       return parameterTypes;
+               }
+
+               Object[] getArgs() {
+                       return args;
+               }
+
+               // Format: method name (UTF), #parameter types, parameter types, args-present flag and,
+               // if present, #args followed by the args
+               private void writeObject(ObjectOutputStream oos) throws IOException {
+                       oos.writeUTF(methodName);
+
+                       oos.writeInt(parameterTypes.length);
+
+                       for (Class<?> parameterType : parameterTypes) {
+                               oos.writeObject(parameterType);
+                       }
+
+                       if (args != null) {
+                               oos.writeBoolean(true);
+                               // write the number of arguments explicitly so that reading does not rely on it
+                               // being equal to the number of parameter types
+                               oos.writeInt(args.length);
+
+                               for (int i = 0; i < args.length; i++) {
+                                       try {
+                                               oos.writeObject(args[i]);
+                                       } catch (IOException e) {
+                                               // args[i] is non-null here: writing null never fails
+                                               throw new IOException("Could not serialize " + i + "th argument of method " +
+                                                       methodName + ". This indicates that the argument type " +
+                                                       args[i].getClass().getName() + " is not serializable. Arguments have to " +
+                                                       "be serializable for remote rpc calls.", e);
+                                       }
+                               }
+                       } else {
+                               oos.writeBoolean(false);
+                       }
+               }
+
+               private void readObject(ObjectInputStream ois) throws IOException, ClassNotFoundException {
+                       methodName = ois.readUTF();
+
+                       int length = ois.readInt();
+
+                       parameterTypes = new Class<?>[length];
+
+                       for (int i = 0; i < length; i++) {
+                               try {
+                                       parameterTypes[i] = (Class<?>) ois.readObject();
+                               } catch (IOException e) {
+                                       throw new IOException("Could not deserialize " + i + "th parameter type of method " +
+                                               methodName + '.', e);
+                               } catch (ClassNotFoundException e) {
+                                       throw new ClassNotFoundException("Could not deserialize " + i + "th " +
+                                               "parameter type of method " + methodName + ". This indicates that the parameter " +
+                                               "type is not part of the system class loader.", e);
+                               }
+                       }
+
+                       boolean hasArgs = ois.readBoolean();
+
+                       if (hasArgs) {
+                               int numberOfArgs = ois.readInt();
+
+                               args = new Object[numberOfArgs];
+
+                               for (int i = 0; i < numberOfArgs; i++) {
+                                       try {
+                                               args[i] = ois.readObject();
+                                       } catch (IOException e) {
+                                               throw new IOException("Could not deserialize " + i + "th argument of method " +
+                                                       methodName + '.', e);
+                                       } catch (ClassNotFoundException e) {
+                                               throw new ClassNotFoundException("Could not deserialize " + i + "th " +
+                                                       "argument of method " + methodName + ". This indicates that the argument " +
+                                                       "type is not part of the system class loader.", e);
+                                       }
+                               }
+                       } else {
+                               args = null;
+                       }
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/433a1fd0/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/RpcInvocation.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/RpcInvocation.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/RpcInvocation.java
index 5d52ef1..b174c99 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/RpcInvocation.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/RpcInvocation.java
@@ -18,81 +18,41 @@
 
 package org.apache.flink.runtime.rpc.akka.messages;
 
-import org.apache.flink.util.Preconditions;
-
 import java.io.IOException;
-import java.io.ObjectInputStream;
-import java.io.ObjectOutputStream;
-import java.io.Serializable;
 
 /**
- * Rpc invocation message containing the remote procedure name, its parameter 
types and the
- * corresponding call arguments.
+ * Interface for rpc invocation messages. The interface allows requesting all necessary information
+ * to look up a method and call it with the corresponding arguments.
  */
-public final class RpcInvocation implements Serializable {
-       private static final long serialVersionUID = -7058254033460536037L;
-
-       private final String methodName;
-       private final Class<?>[] parameterTypes;
-       private transient Object[] args;
-
-       public RpcInvocation(String methodName, Class<?>[] parameterTypes, 
Object[] args) {
-               this.methodName = Preconditions.checkNotNull(methodName);
-               this.parameterTypes = 
Preconditions.checkNotNull(parameterTypes);
-               this.args = args;
-       }
-
-       public String getMethodName() {
-               return methodName;
-       }
-
-       public Class<?>[] getParameterTypes() {
-               return parameterTypes;
-       }
-
-       public Object[] getArgs() {
-               return args;
-       }
-
-       private void writeObject(ObjectOutputStream oos) throws IOException {
-               oos.defaultWriteObject();
-
-               if (args != null) {
-                       // write has args true
-                       oos.writeBoolean(true);
-
-                       for (int i = 0; i < args.length; i++) {
-                               try {
-                                       oos.writeObject(args[i]);
-                               } catch (IOException e) {
-                                       Class<?> argClass = args[i].getClass();
-
-                                       throw new IOException("Could not write 
" + i + "th argument of method " +
-                                               methodName + ". The argument 
type is " + argClass + ". " +
-                                               "Make sure that this type is 
serializable.", e);
-                               }
-                       }
-               } else {
-                       // write has args false
-                       oos.writeBoolean(false);
-               }
-       }
-
-       private void readObject(ObjectInputStream ois) throws IOException, 
ClassNotFoundException {
-               ois.defaultReadObject();
-
-               boolean hasArgs = ois.readBoolean();
-
-               if (hasArgs) {
-                       int numberArguments = parameterTypes.length;
-
-                       args = new Object[numberArguments];
-
-                       for (int i = 0; i < numberArguments; i++) {
-                               args[i] = ois.readObject();
-                       }
-               } else {
-                       args = null;
-               }
-       }
public interface RpcInvocation {

	/**
	 * Gets the name of the method that is to be invoked.
	 *
	 * @return Name of the rpc method
	 * @throws IOException if this is a remote invocation message whose payload could not be
	 * deserialized
	 * @throws ClassNotFoundException if this is a remote invocation message whose payload refers to
	 * classes that are not available on the receiving side
	 */
	String getMethodName() throws IOException, ClassNotFoundException;

	/**
	 * Gets the declared parameter types of the method that is to be invoked.
	 *
	 * @return Parameter types of the rpc method
	 * @throws IOException if this is a remote invocation message whose payload could not be
	 * deserialized
	 * @throws ClassNotFoundException if this is a remote invocation message whose payload refers to
	 * classes that are not available on the receiving side
	 */
	Class<?>[] getParameterTypes() throws IOException, ClassNotFoundException;

	/**
	 * Gets the argument values with which the method is to be invoked.
	 *
	 * @return Arguments of the rpc call
	 * @throws IOException if this is a remote invocation message whose payload could not be
	 * deserialized
	 * @throws ClassNotFoundException if this is a remote invocation message whose payload refers to
	 * classes that are not available on the receiving side
	 */
	Object[] getArgs() throws IOException, ClassNotFoundException;
}

http://git-wip-us.apache.org/repos/asf/flink/blob/433a1fd0/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/akka/AkkaRpcServiceTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/akka/AkkaRpcServiceTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/akka/AkkaRpcServiceTest.java
index 5e37e10..f26b40b 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/akka/AkkaRpcServiceTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/akka/AkkaRpcServiceTest.java
@@ -64,7 +64,7 @@ public class AkkaRpcServiceTest extends TestLogger {
                AkkaGateway akkaClient = (AkkaGateway) rm;
 
                
-               
jobMaster.registerAtResourceManager(AkkaUtils.getAkkaURL(actorSystem, 
akkaClient.getRpcServer()));
+               
jobMaster.registerAtResourceManager(AkkaUtils.getAkkaURL(actorSystem, 
akkaClient.getRpcEndpoint()));
 
                // wait for successful registration
                FiniteDuration timeout = new FiniteDuration(200, 
TimeUnit.SECONDS);

http://git-wip-us.apache.org/repos/asf/flink/blob/433a1fd0/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/akka/MessageSerializationTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/akka/MessageSerializationTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/akka/MessageSerializationTest.java
new file mode 100644
index 0000000..ca8179c
--- /dev/null
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/akka/MessageSerializationTest.java
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.rpc.akka;
+
+import akka.actor.ActorSystem;
+import akka.util.Timeout;
+import com.typesafe.config.Config;
+import com.typesafe.config.ConfigValueFactory;
+import org.apache.flink.runtime.akka.AkkaUtils;
+import org.apache.flink.runtime.rpc.RpcEndpoint;
+import org.apache.flink.runtime.rpc.RpcGateway;
+import org.apache.flink.runtime.rpc.RpcMethod;
+import org.apache.flink.runtime.rpc.RpcService;
+import org.apache.flink.util.TestLogger;
+import org.hamcrest.core.Is;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import scala.concurrent.Await;
+import scala.concurrent.Future;
+import scala.concurrent.duration.FiniteDuration;
+
+import java.io.IOException;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.fail;
+
+/**
+ * Tests that akka rpc invocation messages are properly serialized and errors 
reported
+ */
+public class MessageSerializationTest extends TestLogger {
+       private static ActorSystem actorSystem1;
+       private static ActorSystem actorSystem2;
+       private static AkkaRpcService akkaRpcService1;
+       private static AkkaRpcService akkaRpcService2;
+
+       private static final FiniteDuration timeout = new FiniteDuration(10L, 
TimeUnit.SECONDS);
+       private static final int maxFrameSize = 32000;
+
+       @BeforeClass
+       public static void setup() {
+               Config akkaConfig = AkkaUtils.getDefaultAkkaConfig();
+               Config modifiedAkkaConfig = 
akkaConfig.withValue(AkkaRpcService.MAXIMUM_FRAME_SIZE_PATH, 
ConfigValueFactory.fromAnyRef(maxFrameSize + "b"));
+
+               actorSystem1 = AkkaUtils.createActorSystem(modifiedAkkaConfig);
+               actorSystem2 = AkkaUtils.createActorSystem(modifiedAkkaConfig);
+
+               akkaRpcService1 = new AkkaRpcService(actorSystem1, new 
Timeout(timeout));
+               akkaRpcService2 = new AkkaRpcService(actorSystem2, new 
Timeout(timeout));
+       }
+
+       @AfterClass
+       public static void teardown() {
+               akkaRpcService1.stopService();
+               akkaRpcService2.stopService();
+
+               actorSystem1.shutdown();
+               actorSystem2.shutdown();
+
+               actorSystem1.awaitTermination();
+               actorSystem2.awaitTermination();
+       }
+
+       /**
+        * Tests that a local rpc call with a non serializable argument can be 
executed.
+        */
+       @Test
+       public void testNonSerializableLocalMessageTransfer() throws 
InterruptedException, IOException {
+               LinkedBlockingQueue<Object> linkedBlockingQueue = new 
LinkedBlockingQueue<>();
+               TestEndpoint testEndpoint = new TestEndpoint(akkaRpcService1, 
linkedBlockingQueue);
+
+               TestGateway testGateway = testEndpoint.getSelf();
+
+               NonSerializableObject expected = new NonSerializableObject(42);
+
+               testGateway.foobar(expected);
+
+               assertThat(linkedBlockingQueue.take(), Is.<Object>is(expected));
+       }
+
+       /**
+        * Tests that a remote rpc call with a non-serializable argument fails 
with an
+        * {@link IOException} (or an {@link 
java.lang.reflect.UndeclaredThrowableException} if the
+        * the method declaration does not include the {@link IOException} as 
throwable).
+        */
+       @Test(expected = IOException.class)
+       public void testNonSerializableRemoteMessageTransfer() throws Exception 
{
+               LinkedBlockingQueue<Object> linkedBlockingQueue = new 
LinkedBlockingQueue<>();
+
+               TestEndpoint testEndpoint = new TestEndpoint(akkaRpcService1, 
linkedBlockingQueue);
+
+               String address = testEndpoint.getAddress();
+
+               Future<TestGateway> remoteGatewayFuture = 
akkaRpcService2.connect(address, TestGateway.class);
+
+               TestGateway remoteGateway = Await.result(remoteGatewayFuture, 
timeout);
+
+               remoteGateway.foobar(new Object());
+
+               fail("Should have failed because Object is not serializable.");
+       }
+
+       /**
+        * Tests that a remote rpc call with a serializable argument can be 
successfully executed.
+        */
+       @Test
+       public void testSerializableRemoteMessageTransfer() throws Exception {
+               LinkedBlockingQueue<Object> linkedBlockingQueue = new 
LinkedBlockingQueue<>();
+
+               TestEndpoint testEndpoint = new TestEndpoint(akkaRpcService1, 
linkedBlockingQueue);
+
+               String address = testEndpoint.getAddress();
+
+               Future<TestGateway> remoteGatewayFuture = 
akkaRpcService2.connect(address, TestGateway.class);
+
+               TestGateway remoteGateway = Await.result(remoteGatewayFuture, 
timeout);
+
+               int expected = 42;
+
+               remoteGateway.foobar(expected);
+
+               assertThat(linkedBlockingQueue.take(), Is.<Object>is(expected));
+       }
+
+       /**
+        * Tests that a message which exceeds the maximum frame size is 
detected and a corresponding
+        * exception is thrown.
+        */
+       @Test(expected = IOException.class)
+       public void testMaximumFramesizeRemoteMessageTransfer() throws 
Exception {
+               LinkedBlockingQueue<Object> linkedBlockingQueue = new 
LinkedBlockingQueue<>();
+
+               TestEndpoint testEndpoint = new TestEndpoint(akkaRpcService1, 
linkedBlockingQueue);
+
+               String address = testEndpoint.getAddress();
+
+               Future<TestGateway> remoteGatewayFuture = 
akkaRpcService2.connect(address, TestGateway.class);
+
+               TestGateway remoteGateway = Await.result(remoteGatewayFuture, 
timeout);
+
+               int bufferSize = maxFrameSize + 1;
+               byte[] buffer = new byte[bufferSize];
+
+               remoteGateway.foobar(buffer);
+
+               fail("Should have failed due to exceeding the maximum 
framesize.");
+       }
+
+       private interface TestGateway extends RpcGateway {
+               void foobar(Object object) throws IOException, 
InterruptedException;
+       }
+
+       private static class TestEndpoint extends RpcEndpoint<TestGateway> {
+
+               private final LinkedBlockingQueue<Object> queue;
+
+               protected TestEndpoint(RpcService rpcService, 
LinkedBlockingQueue<Object> queue) {
+                       super(rpcService);
+                       this.queue = queue;
+               }
+
+               @RpcMethod
+               public void foobar(Object object) throws InterruptedException {
+                       queue.put(object);
+               }
+       }
+
+       private static class NonSerializableObject {
+               private final Object object = new Object();
+               private final int value;
+
+               NonSerializableObject(int value) {
+                       this.value = value;
+               }
+
+               @Override
+               public boolean equals(Object obj) {
+                       if (obj instanceof NonSerializableObject) {
+                               NonSerializableObject nonSerializableObject = 
(NonSerializableObject) obj;
+
+                               return value == nonSerializableObject.value;
+                       } else {
+                               return false;
+                       }
+               }
+
+               @Override
+               public int hashCode() {
+                       return value * 41;
+               }
+       }
+}

Reply via email to