wonook closed pull request #45: [NEMO-103] Implement RPC between Client and Driver. URL: https://github.com/apache/incubator-nemo/pull/45
This is a PR merged from a forked repository. Because GitHub hides the original diff of such a PR on merge, the diff is reproduced below for the sake of provenance: diff --git a/client/src/main/java/edu/snu/nemo/client/DriverRPCServer.java b/client/src/main/java/edu/snu/nemo/client/DriverRPCServer.java new file mode 100644 index 00000000..16a07282 --- /dev/null +++ b/client/src/main/java/edu/snu/nemo/client/DriverRPCServer.java @@ -0,0 +1,196 @@ +/* + * Copyright (C) 2018 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package edu.snu.nemo.client; + +import com.google.protobuf.InvalidProtocolBufferException; +import edu.snu.nemo.conf.JobConf; +import edu.snu.nemo.runtime.common.comm.ControlMessage; +import org.apache.reef.annotations.audience.ClientSide; +import org.apache.reef.tang.Configuration; +import org.apache.reef.tang.Injector; +import org.apache.reef.tang.Tang; +import org.apache.reef.tang.exceptions.InjectionException; +import org.apache.reef.wake.EventHandler; +import org.apache.reef.wake.impl.SyncStage; +import org.apache.reef.wake.remote.RemoteConfiguration; +import org.apache.reef.wake.remote.address.LocalAddressProvider; +import org.apache.reef.wake.remote.impl.TransportEvent; +import org.apache.reef.wake.remote.transport.Link; +import org.apache.reef.wake.remote.transport.Transport; +import org.apache.reef.wake.remote.transport.netty.NettyMessagingTransport; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.concurrent.NotThreadSafe; +import java.util.HashMap; +import java.util.Map; + +/** + * Client-side RPC implementation for communication from/to Nemo Driver. + */ +@ClientSide +@NotThreadSafe +public final class DriverRPCServer { + private final Map<ControlMessage.DriverToClientMessageType, EventHandler<ControlMessage.DriverToClientMessage>> + handlers = new HashMap<>(); + private boolean isRunning = false; + private boolean isShutdown = false; + private Transport transport; + private Link link; + private String host; + + private static final Logger LOG = LoggerFactory.getLogger(DriverRPCServer.class); + + /** + * Registers handler for the given type of message. + * @param type the type of message + * @param handler handler implementation + * @return {@code this} + */ + public DriverRPCServer registerHandler(final ControlMessage.DriverToClientMessageType type, + final EventHandler<ControlMessage.DriverToClientMessage> handler) { + // Registering a handler after running the server is considered not a good practice. 
+ ensureServerState(false); + if (handlers.putIfAbsent(type, handler) != null) { + throw new RuntimeException(String.format("A handler for %s already registered", type)); + } + return this; + } + + /** + * Runs the RPC server. + * Specifically, creates a {@link NettyMessagingTransport} and binds it to a listening port. + */ + public void run() { + // Calling 'run' multiple times is considered invalid, since it will override state variables like + // 'transport', and 'host'. + ensureServerState(false); + try { + final Injector injector = Tang.Factory.getTang().newInjector(); + final LocalAddressProvider localAddressProvider = injector.getInstance(LocalAddressProvider.class); + host = localAddressProvider.getLocalAddress(); + injector.bindVolatileParameter(RemoteConfiguration.HostAddress.class, host); + injector.bindVolatileParameter(RemoteConfiguration.Port.class, 0); + injector.bindVolatileParameter(RemoteConfiguration.RemoteServerStage.class, + new SyncStage<>(new ServerEventHandler())); + transport = injector.getInstance(NettyMessagingTransport.class); + LOG.info("DriverRPCServer running at {}", transport.getListeningPort()); + isRunning = true; + } catch (final InjectionException e) { + throw new RuntimeException(e); + } + } + + /** + * @return the listening port + */ + public int getListeningPort() { + // We cannot determine listening port if the server is not listening. + ensureServerState(true); + return transport.getListeningPort(); + } + + /** + * @return the host of the client + */ + public String getListeningHost() { + // Listening host is determined by LocalAddressProvider, in 'run' method. 
+ ensureServerState(true); + return host; + } + + /** + * @return the configuration for RPC server listening information + */ + public Configuration getListeningConfiguration() { + return Tang.Factory.getTang().newConfigurationBuilder() + .bindNamedParameter(JobConf.ClientSideRPCServerHost.class, getListeningHost()) + .bindNamedParameter(JobConf.ClientSideRPCServerPort.class, String.valueOf(getListeningPort())) + .build(); + } + + /** + * Sends a message to driver. + * @param message message to send + */ + public void send(final ControlMessage.ClientToDriverMessage message) { + // This needs active 'link' between the driver and client. + // For the link to be alive, the driver should connect to DriverRPCServer. + // Thus, the server must be running to send a message to the driver. + ensureServerState(true); + if (link == null) { + throw new RuntimeException("The RPC server has not discovered NemoDriver yet"); + } + link.write(message.toByteArray()); + } + + /** + * Shut down the server. + */ + public void shutdown() { + // Shutting down a 'null' transport is invalid. Also, shutting down a server for multiple times is invalid. + ensureServerState(true); + try { + transport.close(); + } catch (final Exception e) { + throw new RuntimeException(e); + } finally { + isShutdown = true; + } + } + + /** + * Handles messages from driver. 
+ */ + private final class ServerEventHandler implements EventHandler<TransportEvent> { + @Override + public void onNext(final TransportEvent transportEvent) { + final byte[] bytes = transportEvent.getData(); + final ControlMessage.DriverToClientMessage message; + try { + message = ControlMessage.DriverToClientMessage.parseFrom(bytes); + } catch (final InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } + + final ControlMessage.DriverToClientMessageType type = message.getType(); + + if (type == ControlMessage.DriverToClientMessageType.DriverStarted) { + link = transportEvent.getLink(); + } + + final EventHandler<ControlMessage.DriverToClientMessage> handler = handlers.get(type); + if (handler == null) { + throw new RuntimeException(String.format("Handler for message type %s not registered", type)); + } else { + handler.onNext(message); + } + } + } + + /** + * Throws a {@link RuntimeException} if the server is shut down, or it has different state than the expected state. + * @param running the expected state of the server + */ + private void ensureServerState(final boolean running) { + if (isShutdown) { + throw new RuntimeException("The DriverRPCServer is already shutdown"); + } + if (running != isRunning) { + throw new RuntimeException(String.format("The DriverRPCServer is %s running", isRunning ? 
"already" : "not")); + } + } +} diff --git a/client/src/main/java/edu/snu/nemo/client/JobLauncher.java b/client/src/main/java/edu/snu/nemo/client/JobLauncher.java index 0b7ca67b..5c330d7c 100644 --- a/client/src/main/java/edu/snu/nemo/client/JobLauncher.java +++ b/client/src/main/java/edu/snu/nemo/client/JobLauncher.java @@ -19,6 +19,7 @@ import edu.snu.nemo.common.dag.DAG; import edu.snu.nemo.conf.JobConf; import edu.snu.nemo.driver.NemoDriver; +import edu.snu.nemo.runtime.common.comm.ControlMessage; import edu.snu.nemo.runtime.common.message.MessageEnvironment; import edu.snu.nemo.runtime.common.message.MessageParameters; import org.apache.commons.lang3.SerializationUtils; @@ -58,6 +59,7 @@ private static Configuration jobAndDriverConf = null; private static Configuration deployModeConf = null; private static Configuration builtJobConf = null; + private static String serializedDAG; /** * private constructor. @@ -72,6 +74,16 @@ private JobLauncher() { * @throws Exception exception on the way. */ public static void main(final String[] args) throws Exception { + final DriverRPCServer driverRPCServer = new DriverRPCServer(); + // Registers actions for launching the DAG. 
+ driverRPCServer + .registerHandler(ControlMessage.DriverToClientMessageType.DriverStarted, event -> { }) + .registerHandler(ControlMessage.DriverToClientMessageType.ResourceReady, event -> + driverRPCServer.send(ControlMessage.ClientToDriverMessage.newBuilder() + .setType(ControlMessage.ClientToDriverMessageType.LaunchDAG) + .setLaunchDAG(ControlMessage.LaunchDAGMessage.newBuilder().setDag(serializedDAG).build()).build())) + .run(); + // Get Job and Driver Confs builtJobConf = getJobConf(args); final Configuration driverConf = getDriverConf(builtJobConf); @@ -82,13 +94,15 @@ public static void main(final String[] args) throws Exception { // Merge Job and Driver Confs jobAndDriverConf = Configurations.merge(builtJobConf, driverConf, driverNcsConf, driverMessageConfg, - executorResourceConfig); + executorResourceConfig, driverRPCServer.getListeningConfiguration()); // Get DeployMode Conf deployModeConf = Configurations.merge(getDeployModeConf(builtJobConf), clientConf); // Launch client main runUserProgramMain(builtJobConf); + + driverRPCServer.shutdown(); } /** @@ -102,13 +116,10 @@ public static void launchDAG(final DAG dag) { if (jobAndDriverConf == null || deployModeConf == null || builtJobConf == null) { throw new RuntimeException("Configuration for launching driver is not ready"); } - final String serializedDAG = Base64.getEncoder().encodeToString(SerializationUtils.serialize(dag)); - final Configuration dagConf = TANG.newConfigurationBuilder() - .bindNamedParameter(JobConf.SerializedDAG.class, serializedDAG) - .build(); + serializedDAG = Base64.getEncoder().encodeToString(SerializationUtils.serialize(dag)); // Launch and wait indefinitely for the job to finish final LauncherStatus launcherStatus = DriverLauncher.getLauncher(deployModeConf) - .run(Configurations.merge(jobAndDriverConf, dagConf)); + .run(jobAndDriverConf); final Optional<Throwable> possibleError = launcherStatus.getError(); if (possibleError.isPresent()) { throw new 
RuntimeException(possibleError.get()); diff --git a/conf/src/main/java/edu/snu/nemo/conf/JobConf.java b/conf/src/main/java/edu/snu/nemo/conf/JobConf.java index da3d6713..1e4ef4de 100644 --- a/conf/src/main/java/edu/snu/nemo/conf/JobConf.java +++ b/conf/src/main/java/edu/snu/nemo/conf/JobConf.java @@ -73,6 +73,22 @@ public final class GlusterVolumeDirectory implements Name<String> { } + //////////////////////////////// Client-Driver RPC + + /** + * Host of the client-side RPC server. + */ + @NamedParameter + public final class ClientSideRPCServerHost implements Name<String> { + } + + /** + * Port of the client-side RPC server. + */ + @NamedParameter + public final class ClientSideRPCServerPort implements Name<Integer> { + } + //////////////////////////////// Compiler Configurations /** @@ -227,13 +243,6 @@ public final class ExecutorId implements Name<String> { } - /** - * Serialized {edu.snu.nemo.common.dag.DAG} from user main method. - */ - @NamedParameter(doc = "String serialized DAG") - public final class SerializedDAG implements Name<String> { - } - public static final RequiredParameter<String> EXECUTOR_ID = new RequiredParameter<>(); public static final RequiredParameter<String> JOB_ID = new RequiredParameter<>(); public static final OptionalParameter<String> LOCAL_DISK_DIRECTORY = new OptionalParameter<>(); diff --git a/runtime/common/src/main/proto/ControlMessage.proto b/runtime/common/src/main/proto/ControlMessage.proto index 664734b1..f6bd527e 100644 --- a/runtime/common/src/main/proto/ControlMessage.proto +++ b/runtime/common/src/main/proto/ControlMessage.proto @@ -19,6 +19,28 @@ package protobuf; option java_package = "edu.snu.nemo.runtime.common.comm"; option java_outer_classname = "ControlMessage"; +enum ClientToDriverMessageType { + LaunchDAG = 0; +} + +message ClientToDriverMessage { + required ClientToDriverMessageType type = 1; + optional LaunchDAGMessage launchDAG = 2; +} + +message LaunchDAGMessage { + required string dag = 1; +} + +enum 
DriverToClientMessageType { + DriverStarted = 0; + ResourceReady = 1; +} + +message DriverToClientMessage { + required DriverToClientMessageType type = 1; +} + enum MessageType { TaskStateChanged = 0; ScheduleTask = 1; diff --git a/runtime/driver/src/main/java/edu/snu/nemo/driver/ClientRPC.java b/runtime/driver/src/main/java/edu/snu/nemo/driver/ClientRPC.java new file mode 100644 index 00000000..82698f03 --- /dev/null +++ b/runtime/driver/src/main/java/edu/snu/nemo/driver/ClientRPC.java @@ -0,0 +1,167 @@ +/* + * Copyright (C) 2018 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package edu.snu.nemo.driver; + +import com.google.protobuf.InvalidProtocolBufferException; +import edu.snu.nemo.conf.JobConf; +import edu.snu.nemo.runtime.common.comm.ControlMessage; +import org.apache.reef.tang.annotations.Parameter; +import org.apache.reef.wake.EventHandler; +import org.apache.reef.wake.impl.SyncStage; +import org.apache.reef.wake.remote.Encoder; +import org.apache.reef.wake.remote.address.LocalAddressProvider; +import org.apache.reef.wake.remote.impl.TransportEvent; +import org.apache.reef.wake.remote.transport.Link; +import org.apache.reef.wake.remote.transport.LinkListener; +import org.apache.reef.wake.remote.transport.Transport; +import org.apache.reef.wake.remote.transport.TransportFactory; + +import javax.inject.Inject; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Driver-side RPC implementation for communication from/to Nemo Client. 
+ */ +public final class ClientRPC { + private static final DriverToClientMessageEncoder ENCODER = new DriverToClientMessageEncoder(); + private static final ClientRPCLinkListener LINK_LISTENER = new ClientRPCLinkListener(); + private static final int RETRY_COUNT = 10; + private static final int RETRY_TIMEOUT = 100; + + private final Map<ControlMessage.ClientToDriverMessageType, EventHandler<ControlMessage.ClientToDriverMessage>> + handlers = new ConcurrentHashMap<>(); + private final Transport transport; + private final Link<ControlMessage.DriverToClientMessage> link; + private volatile boolean isClosed = false; + + @Inject + private ClientRPC(final TransportFactory transportFactory, + final LocalAddressProvider localAddressProvider, + @Parameter(JobConf.ClientSideRPCServerHost.class) final String clientHost, + @Parameter(JobConf.ClientSideRPCServerPort.class) final int clientPort) throws IOException { + transport = transportFactory.newInstance(localAddressProvider.getLocalAddress(), + 0, new SyncStage<>(new RPCEventHandler()), null, RETRY_COUNT, RETRY_TIMEOUT); + final SocketAddress clientAddress = new InetSocketAddress(clientHost, clientPort); + link = transport.open(clientAddress, ENCODER, LINK_LISTENER); + } + + /** + * Registers handler for the given type of message. + * @param type the type of message + * @param handler handler implementation + * @return {@code this} + */ + public ClientRPC registerHandler(final ControlMessage.ClientToDriverMessageType type, + final EventHandler<ControlMessage.ClientToDriverMessage> handler) { + if (handlers.putIfAbsent(type, handler) != null) { + throw new RuntimeException(String.format("A handler for %s already registered", type)); + } + return this; + } + + /** + * Shuts down the transport. + */ + public void shutdown() { + ensureRunning(); + try { + transport.close(); + } catch (final Exception e) { + throw new RuntimeException(e); + } finally { + isClosed = true; + } + } + + /** + * Write message to client. 
+ * @param message message to send. + */ + public void send(final ControlMessage.DriverToClientMessage message) { + ensureRunning(); + link.write(message); + } + + /** + * Handles message from client. + * @param message message to process + */ + private void handleMessage(final ControlMessage.ClientToDriverMessage message) { + final ControlMessage.ClientToDriverMessageType type = message.getType(); + final EventHandler<ControlMessage.ClientToDriverMessage> handler = handlers.get(type); + if (handler == null) { + throw new RuntimeException(String.format("Handler for message type %s not registered", type)); + } else { + handler.onNext(message); + } + } + + /** + * Provides event handler for messages from client. + */ + private final class RPCEventHandler implements EventHandler<TransportEvent> { + @Override + public void onNext(final TransportEvent transportEvent) { + try { + final byte[] data = transportEvent.getData(); + final ControlMessage.ClientToDriverMessage message = ControlMessage.ClientToDriverMessage.parseFrom(data); + handleMessage(message); + } catch (final InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } + } + } + + /** + * Ensure the Transport is running. + */ + private void ensureRunning() { + if (isClosed) { + throw new RuntimeException("The ClientRPC is already closed"); + } + } + + /** + * Provides encoder for {@link edu.snu.nemo.runtime.common.comm.ControlMessage.DriverToClientMessage}. + */ + private static final class DriverToClientMessageEncoder implements Encoder<ControlMessage.DriverToClientMessage> { + @Override + public byte[] encode(final ControlMessage.DriverToClientMessage driverToClientMessage) { + return driverToClientMessage.toByteArray(); + } + } + + /** + * Provides {@link LinkListener}. 
+ */ + private static final class ClientRPCLinkListener implements LinkListener<ControlMessage.DriverToClientMessage> { + + @Override + public void onSuccess(final ControlMessage.DriverToClientMessage driverToClientMessage) { + } + + @Override + public void onException(final Throwable throwable, + final SocketAddress socketAddress, + final ControlMessage.DriverToClientMessage driverToClientMessage) { + throw new RuntimeException(throwable); + } + } +} diff --git a/runtime/driver/src/main/java/edu/snu/nemo/driver/NemoDriver.java b/runtime/driver/src/main/java/edu/snu/nemo/driver/NemoDriver.java index e2e263cd..8102d645 100644 --- a/runtime/driver/src/main/java/edu/snu/nemo/driver/NemoDriver.java +++ b/runtime/driver/src/main/java/edu/snu/nemo/driver/NemoDriver.java @@ -18,6 +18,7 @@ import edu.snu.nemo.common.ir.IdManager; import edu.snu.nemo.conf.JobConf; import edu.snu.nemo.runtime.common.RuntimeIdGenerator; +import edu.snu.nemo.runtime.common.comm.ControlMessage; import edu.snu.nemo.runtime.common.message.MessageParameters; import edu.snu.nemo.runtime.master.RuntimeMaster; import org.apache.reef.annotations.audience.DriverSide; @@ -67,9 +68,9 @@ private final String jobId; private final String localDirectory; private final String glusterDirectory; + private final ClientRPC clientRPC; // Client for sending log messages - private final JobMessageObserver client; private final RemoteClientMessageLoggingHandler handler; @Inject @@ -78,6 +79,7 @@ private NemoDriver(final UserApplicationRunner userApplicationRunner, final NameServer nameServer, final LocalAddressProvider localAddressProvider, final JobMessageObserver client, + final ClientRPC clientRPC, @Parameter(JobConf.ExecutorJsonContents.class) final String resourceSpecificationString, @Parameter(JobConf.JobId.class) final String jobId, @Parameter(JobConf.FileDirectory.class) final String localDirectory, @@ -91,8 +93,13 @@ private NemoDriver(final UserApplicationRunner userApplicationRunner, this.jobId = jobId; 
this.localDirectory = localDirectory; this.glusterDirectory = glusterDirectory; - this.client = client; this.handler = new RemoteClientMessageLoggingHandler(client); + this.clientRPC = clientRPC; + clientRPC.registerHandler(ControlMessage.ClientToDriverMessageType.LaunchDAG, + message -> startSchedulingUserApplication(message.getLaunchDAG().getDag())); + // Send DriverStarted message to the client + clientRPC.send(ControlMessage.DriverToClientMessage.newBuilder() + .setType(ControlMessage.DriverToClientMessageType.DriverStarted).build()); } /** @@ -135,15 +142,19 @@ public void onNext(final ActiveContext activeContext) { final boolean finalExecutorLaunched = runtimeMaster.onExecutorLaunched(activeContext); if (finalExecutorLaunched) { - startSchedulingUserApplication(); + clientRPC.send(ControlMessage.DriverToClientMessage.newBuilder() + .setType(ControlMessage.DriverToClientMessageType.ResourceReady).build()); } } } - private void startSchedulingUserApplication() { + /** + * Start user application. 
+ */ + public void startSchedulingUserApplication(final String dagString) { // Launch user application (with a new thread) final ExecutorService userApplicationRunnerThread = Executors.newSingleThreadExecutor(); - userApplicationRunnerThread.execute(userApplicationRunner); + userApplicationRunnerThread.execute(() -> userApplicationRunner.run(dagString)); userApplicationRunnerThread.shutdown(); } @@ -175,6 +186,7 @@ public void onNext(final FailedContext failedContext) { @Override public void onNext(final StopTime stopTime) { handler.close(); + clientRPC.shutdown(); } } diff --git a/runtime/driver/src/main/java/edu/snu/nemo/driver/UserApplicationRunner.java b/runtime/driver/src/main/java/edu/snu/nemo/driver/UserApplicationRunner.java index 3e415d8e..6ba615f4 100644 --- a/runtime/driver/src/main/java/edu/snu/nemo/driver/UserApplicationRunner.java +++ b/runtime/driver/src/main/java/edu/snu/nemo/driver/UserApplicationRunner.java @@ -42,11 +42,10 @@ /** * Compiles and runs User application. */ -public final class UserApplicationRunner implements Runnable { +public final class UserApplicationRunner { private static final Logger LOG = LoggerFactory.getLogger(UserApplicationRunner.class.getName()); private final String dagDirectory; - private final String dagString; private final String optimizationPolicyCanonicalName; private final int maxScheduleAttempt; @@ -58,14 +57,12 @@ @Inject private UserApplicationRunner(@Parameter(JobConf.DAGDirectory.class) final String dagDirectory, - @Parameter(JobConf.SerializedDAG.class) final String dagString, @Parameter(JobConf.OptimizationPolicy.class) final String optimizationPolicy, @Parameter(JobConf.MaxScheduleAttempt.class) final int maxScheduleAttempt, final PubSubEventHandlerWrapper pubSubEventHandlerWrapper, final Injector injector, final RuntimeMaster runtimeMaster) { this.dagDirectory = dagDirectory; - this.dagString = dagString; this.optimizationPolicyCanonicalName = optimizationPolicy; this.maxScheduleAttempt = 
maxScheduleAttempt; this.injector = injector; @@ -74,8 +71,14 @@ private UserApplicationRunner(@Parameter(JobConf.DAGDirectory.class) final Strin this.pubSubWrapper = pubSubEventHandlerWrapper; } - @Override - public void run() { + /** + * Run the user program submitted by Nemo Client. + * Specifically, deserialize DAG from Client, optimize it, generate physical plan, + * and tell {@link RuntimeMaster} to execute the plan. + * + * @param dagString Serialized IR DAG from Nemo Client. + */ + public void run(final String dagString) { try { LOG.info("##### Nemo Compiler #####"); diff --git a/tests/src/test/java/edu/snu/nemo/client/ClientDriverRPCTest.java b/tests/src/test/java/edu/snu/nemo/client/ClientDriverRPCTest.java new file mode 100644 index 00000000..af94b597 --- /dev/null +++ b/tests/src/test/java/edu/snu/nemo/client/ClientDriverRPCTest.java @@ -0,0 +1,97 @@ +/* + * Copyright (C) 2018 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package edu.snu.nemo.client; + +import edu.snu.nemo.driver.ClientRPC; +import edu.snu.nemo.runtime.common.comm.ControlMessage; +import org.apache.reef.tang.Injector; +import org.apache.reef.tang.Tang; +import org.apache.reef.tang.exceptions.InjectionException; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.util.concurrent.CountDownLatch; + +/** + * Test for communication between {@link DriverRPCServer} and {@link ClientRPC}. 
+ */ +public final class ClientDriverRPCTest { + private DriverRPCServer driverRPCServer; + private ClientRPC clientRPC; + @Before + public void setupDriverRPCServer() { + // Initialize DriverRPCServer. + driverRPCServer = new DriverRPCServer(); + } + + private void setupClientRPC() throws InjectionException { + driverRPCServer.run(); + final Injector clientRPCInjector = Tang.Factory.getTang().newInjector(driverRPCServer.getListeningConfiguration()); + clientRPC = clientRPCInjector.getInstance(ClientRPC.class); + } + + @After + public void cleanup() { + driverRPCServer.shutdown(); + clientRPC.shutdown(); + } + + /** + * Test with empty set of handlers. + * @throws InjectionException on Exceptions on creating {@link ClientRPC}. + */ + @Test + public void testRPCSetup() throws InjectionException { + setupClientRPC(); + } + + /** + * Test with basic request method from driver to client. + * @throws InjectionException on Exceptions on creating {@link ClientRPC}. + * @throws InterruptedException when interrupted while waiting EventHandler invocation + */ + @Test + public void testDriverToClientMethodInvocation() throws InjectionException, InterruptedException { + final CountDownLatch latch = new CountDownLatch(1); + driverRPCServer.registerHandler(ControlMessage.DriverToClientMessageType.DriverStarted, + msg -> latch.countDown()); + setupClientRPC(); + clientRPC.send(ControlMessage.DriverToClientMessage.newBuilder() + .setType(ControlMessage.DriverToClientMessageType.DriverStarted).build()); + latch.await(); + } + + /** + * Test with request-response RPC between client and driver. + * @throws InjectionException on Exceptions on creating {@link ClientRPC}. 
+ * @throws InterruptedException when interrupted while waiting EventHandler invocation + */ + @Test + public void testBetweenClientAndDriver() throws InjectionException, InterruptedException { + final CountDownLatch latch = new CountDownLatch(1); + driverRPCServer.registerHandler(ControlMessage.DriverToClientMessageType.DriverStarted, + msg -> driverRPCServer.send(ControlMessage.ClientToDriverMessage.newBuilder() + .setType(ControlMessage.ClientToDriverMessageType.LaunchDAG) + .setLaunchDAG(ControlMessage.LaunchDAGMessage.newBuilder().setDag("").build()) + .build())); + setupClientRPC(); + clientRPC.registerHandler(ControlMessage.ClientToDriverMessageType.LaunchDAG, msg -> latch.countDown()); + clientRPC.send(ControlMessage.DriverToClientMessage.newBuilder() + .setType(ControlMessage.DriverToClientMessageType.DriverStarted).build()); + latch.await(); + } +} ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org With regards, Apache Git Services
