[FLINK-4392] [rpc] Make RPC Service thread-safe
Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/86f21bf9 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/86f21bf9 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/86f21bf9 Branch: refs/heads/flip-6 Commit: 86f21bf9d6daee9bdf6af1e9052faa058203133f Parents: f5614a4 Author: Stephan Ewen <se...@apache.org> Authored: Sat Aug 13 19:11:47 2016 +0200 Committer: Till Rohrmann <trohrm...@apache.org> Committed: Wed Sep 21 11:39:13 2016 +0200 ---------------------------------------------------------------------- .../flink/runtime/rpc/akka/AkkaGateway.java | 3 +- .../flink/runtime/rpc/akka/AkkaRpcService.java | 92 +++++++++++++++----- 2 files changed, 70 insertions(+), 25 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/86f21bf9/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaGateway.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaGateway.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaGateway.java index a826e7d..ec3091c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaGateway.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaGateway.java @@ -19,11 +19,12 @@ package org.apache.flink.runtime.rpc.akka; import akka.actor.ActorRef; +import org.apache.flink.runtime.rpc.RpcGateway; /** * Interface for Akka based rpc gateways */ -interface AkkaGateway { +interface AkkaGateway extends RpcGateway { ActorRef getRpcServer(); } http://git-wip-us.apache.org/repos/asf/flink/blob/86f21bf9/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcService.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcService.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcService.java index 17983d0..448216c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcService.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcService.java @@ -28,47 +28,61 @@ import akka.actor.Props; import akka.dispatch.Mapper; import akka.pattern.AskableActorSelection; import akka.util.Timeout; + import org.apache.flink.runtime.akka.AkkaUtils; import org.apache.flink.runtime.rpc.MainThreadExecutor; import org.apache.flink.runtime.rpc.RpcGateway; import org.apache.flink.runtime.rpc.RpcEndpoint; import org.apache.flink.runtime.rpc.RpcService; -import org.apache.flink.util.Preconditions; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import scala.concurrent.Future; +import javax.annotation.concurrent.ThreadSafe; import java.lang.reflect.InvocationHandler; import java.lang.reflect.Proxy; -import java.util.Collection; import java.util.HashSet; +import java.util.Set; + +import static org.apache.flink.util.Preconditions.checkNotNull; +import static org.apache.flink.util.Preconditions.checkState; /** - * Akka based {@link RpcService} implementation. The rpc service starts an Akka actor to receive - * rpcs from a {@link RpcGateway}. + * Akka based {@link RpcService} implementation. The RPC service starts an Akka actor to receive + * RPC invocations from a {@link RpcGateway}. */ +@ThreadSafe public class AkkaRpcService implements RpcService { + private static final Logger LOG = LoggerFactory.getLogger(AkkaRpcService.class); + private final Object lock = new Object(); + private final ActorSystem actorSystem; private final Timeout timeout; - private final Collection<ActorRef> actors = new HashSet<>(4); + private final Set<ActorRef> actors = new HashSet<>(4); + + private volatile boolean stopped; public AkkaRpcService(final ActorSystem actorSystem, final Timeout timeout) { - this.actorSystem = Preconditions.checkNotNull(actorSystem, "actor system"); - this.timeout = Preconditions.checkNotNull(timeout, "timeout"); + this.actorSystem = checkNotNull(actorSystem, "actor system"); + this.timeout = checkNotNull(timeout, "timeout"); } + // this method does not mutate state and is thus thread-safe @Override public <C extends RpcGateway> Future<C> connect(final String address, final Class<C> clazz) { - LOG.info("Try to connect to remote rpc server with address {}. Returning a {} gateway.", address, clazz.getName()); + checkState(!stopped, "RpcService is stopped"); - final ActorSelection actorSel = actorSystem.actorSelection(address); + LOG.debug("Try to connect to remote RPC endpoint with address {}. Returning a {} gateway.", + address, clazz.getName()); + final ActorSelection actorSel = actorSystem.actorSelection(address); final AskableActorSelection asker = new AskableActorSelection(actorSel); final Future<Object> identify = asker.ask(new Identify(42), timeout); - return identify.map(new Mapper<Object, C>(){ @Override public C apply(Object obj) { @@ -89,20 +103,29 @@ public class AkkaRpcService implements RpcService { @Override public <C extends RpcGateway, S extends RpcEndpoint<C>> C startServer(S rpcEndpoint) { - Preconditions.checkNotNull(rpcEndpoint, "rpc endpoint"); - - LOG.info("Start Akka rpc actor to handle rpcs for {}.", rpcEndpoint.getClass().getName()); + checkNotNull(rpcEndpoint, "rpc endpoint"); Props akkaRpcActorProps = Props.create(AkkaRpcActor.class, rpcEndpoint); + ActorRef actorRef; + + synchronized (lock) { + checkState(!stopped, "RpcService is stopped"); + actorRef = actorSystem.actorOf(akkaRpcActorProps); + actors.add(actorRef); + } - ActorRef actorRef = actorSystem.actorOf(akkaRpcActorProps); - actors.add(actorRef); + LOG.info("Starting RPC endpoint for {} at {} .", rpcEndpoint.getClass().getName(), actorRef.path()); InvocationHandler akkaInvocationHandler = new AkkaInvocationHandler(actorRef, timeout); + // Rather than using the System ClassLoader directly, we derive the ClassLoader + // from this class . That works better in cases where Flink runs embedded and all Flink + // code is loaded dynamically (for example from an OSGI bundle) through a custom ClassLoader + ClassLoader classLoader = getClass().getClassLoader(); + @SuppressWarnings("unchecked") C self = (C) Proxy.newProxyInstance( - ClassLoader.getSystemClassLoader(), + classLoader, new Class<?>[]{rpcEndpoint.getSelfGatewayType(), MainThreadExecutor.class, AkkaGateway.class}, akkaInvocationHandler); @@ -110,35 +133,56 @@ public class AkkaRpcService implements RpcService { } @Override - public <C extends RpcGateway> void stopServer(C selfGateway) { + public void stopServer(RpcGateway selfGateway) { if (selfGateway instanceof AkkaGateway) { AkkaGateway akkaClient = (AkkaGateway) selfGateway; - if (actors.contains(akkaClient.getRpcServer())) { - ActorRef selfActorRef = akkaClient.getRpcServer(); - - LOG.info("Stop Akka rpc actor {}.", selfActorRef.path()); + boolean fromThisService; + synchronized (lock) { + if (stopped) { + return; + } else { + fromThisService = actors.remove(akkaClient.getRpcServer()); + } + } + if (fromThisService) { + ActorRef selfActorRef = akkaClient.getRpcServer(); + LOG.info("Stopping RPC endpoint {}.", selfActorRef.path()); selfActorRef.tell(PoisonPill.getInstance(), ActorRef.noSender()); + } else { + LOG.debug("RPC endpoint {} already stopped or from different RPC service"); } } } @Override public void stopService() { - LOG.info("Stop Akka rpc service."); - actorSystem.shutdown(); + LOG.info("Stopping Akka RPC service."); + + synchronized (lock) { + if (stopped) { + return; + } + + stopped = true; + actorSystem.shutdown(); + actors.clear(); + } + actorSystem.awaitTermination(); } @Override public <C extends RpcGateway> String getAddress(C selfGateway) { + checkState(!stopped, "RpcService is stopped"); + if (selfGateway instanceof AkkaGateway) { ActorRef actorRef = ((AkkaGateway) selfGateway).getRpcServer(); return AkkaUtils.getAkkaURL(actorSystem, actorRef); } else { String className = AkkaGateway.class.getName(); - throw new RuntimeException("Cannot get address for non " + className + '.'); + throw new IllegalArgumentException("Cannot get address for non " + className + '.'); } } }