You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@nemo.apache.org by wo...@apache.org on 2018/06/19 04:49:54 UTC
[incubator-nemo] branch master updated: [NEMO-103] Implement RPC
between Client and Driver (#45)
This is an automated email from the ASF dual-hosted git repository.
wonook pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-nemo.git
The following commit(s) were added to refs/heads/master by this push:
new 39482ae [NEMO-103] Implement RPC between Client and Driver (#45)
39482ae is described below
commit 39482ae358fbe343a51060d90e46e3bf7bbf0f15
Author: Jangho Seo <ja...@jangho.io>
AuthorDate: Tue Jun 19 13:49:52 2018 +0900
[NEMO-103] Implement RPC between Client and Driver (#45)
JIRA: [NEMO-103: Communication between Driver and Client](https://issues.apache.org/jira/projects/NEMO/issues/NEMO-103)
**Major changes:**
- Implement RPC between NemoClient and NemoDriver using NettyMessagingTransport
- Use the RPC stack to submit the DAG to the driver, which allows multiple DAGs to be submitted to a single REEF instance.
**Minor changes to note:**
- N/A
**Tests for the changes:**
- Added 'ClientDriverRPCTest'
- Existing integration tests also cover this change.
**Other comments:**
- N/A
resolves [NEMO-103](https://issues.apache.org/jira/projects/NEMO/issues/NEMO-103)
---
.../java/edu/snu/nemo/client/DriverRPCServer.java | 196 +++++++++++++++++++++
.../main/java/edu/snu/nemo/client/JobLauncher.java | 23 ++-
conf/src/main/java/edu/snu/nemo/conf/JobConf.java | 23 ++-
runtime/common/src/main/proto/ControlMessage.proto | 22 +++
.../main/java/edu/snu/nemo/driver/ClientRPC.java | 167 ++++++++++++++++++
.../main/java/edu/snu/nemo/driver/NemoDriver.java | 22 ++-
.../edu/snu/nemo/driver/UserApplicationRunner.java | 15 +-
.../edu/snu/nemo/client/ClientDriverRPCTest.java | 97 ++++++++++
8 files changed, 541 insertions(+), 24 deletions(-)
diff --git a/client/src/main/java/edu/snu/nemo/client/DriverRPCServer.java b/client/src/main/java/edu/snu/nemo/client/DriverRPCServer.java
new file mode 100644
index 0000000..16a0728
--- /dev/null
+++ b/client/src/main/java/edu/snu/nemo/client/DriverRPCServer.java
@@ -0,0 +1,196 @@
+/*
+ * Copyright (C) 2018 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.client;
+
+import com.google.protobuf.InvalidProtocolBufferException;
+import edu.snu.nemo.conf.JobConf;
+import edu.snu.nemo.runtime.common.comm.ControlMessage;
+import org.apache.reef.annotations.audience.ClientSide;
+import org.apache.reef.tang.Configuration;
+import org.apache.reef.tang.Injector;
+import org.apache.reef.tang.Tang;
+import org.apache.reef.tang.exceptions.InjectionException;
+import org.apache.reef.wake.EventHandler;
+import org.apache.reef.wake.impl.SyncStage;
+import org.apache.reef.wake.remote.RemoteConfiguration;
+import org.apache.reef.wake.remote.address.LocalAddressProvider;
+import org.apache.reef.wake.remote.impl.TransportEvent;
+import org.apache.reef.wake.remote.transport.Link;
+import org.apache.reef.wake.remote.transport.Transport;
+import org.apache.reef.wake.remote.transport.netty.NettyMessagingTransport;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.concurrent.NotThreadSafe;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Client-side RPC implementation for communication from/to Nemo Driver.
+ *
+ * <p>Lifecycle: register handlers with {@link #registerHandler}, start listening with {@link #run()},
+ * hand {@link #getListeningConfiguration()} to the driver, exchange messages, then {@link #shutdown()}.</p>
+ *
+ * <p>The class is not thread-safe in general: handlers must be registered before {@link #run()}, and the
+ * lifecycle methods are meant to be driven from a single thread. However, {@link #send} may be called from a
+ * message handler, i.e. from the transport's receiving thread, so the fields it reads ({@code link} and the
+ * lifecycle flags) are {@code volatile}.</p>
+ */
+@ClientSide
+@NotThreadSafe
+public final class DriverRPCServer {
+  private static final Logger LOG = LoggerFactory.getLogger(DriverRPCServer.class);
+
+  private final Map<ControlMessage.DriverToClientMessageType, EventHandler<ControlMessage.DriverToClientMessage>>
+      handlers = new HashMap<>();
+  // Lifecycle flags; read by send(), which may run on the transport's receiving thread.
+  private volatile boolean isRunning = false;
+  private volatile boolean isShutdown = false;
+  private Transport transport;
+  private String host;
+  // Link back to the driver; written by the receiving thread when a DriverStarted message arrives.
+  private volatile Link<byte[]> link;
+
+  /**
+   * Registers handler for the given type of message.
+   * Must be called before {@link #run()}: registering a handler on a running server is rejected.
+   *
+   * @param type    the type of message
+   * @param handler handler implementation
+   * @return {@code this}, for chaining
+   * @throws IllegalStateException if the server is already running or shut down, or if a handler for
+   *                               {@code type} is already registered
+   */
+  public DriverRPCServer registerHandler(final ControlMessage.DriverToClientMessageType type,
+                                         final EventHandler<ControlMessage.DriverToClientMessage> handler) {
+    ensureServerState(false);
+    if (handlers.putIfAbsent(type, handler) != null) {
+      throw new IllegalStateException(String.format("A handler for %s already registered", type));
+    }
+    return this;
+  }
+
+  /**
+   * Runs the RPC server.
+   * Specifically, creates a {@link NettyMessagingTransport} bound to the local address.
+   * May be called only once, since it initializes {@code transport} and {@code host}.
+   *
+   * @throws IllegalStateException if the server is already running or shut down
+   * @throws RuntimeException      wrapping an {@link InjectionException} if the transport cannot be created
+   */
+  public void run() {
+    ensureServerState(false);
+    try {
+      final Injector injector = Tang.Factory.getTang().newInjector();
+      final LocalAddressProvider localAddressProvider = injector.getInstance(LocalAddressProvider.class);
+      host = localAddressProvider.getLocalAddress();
+      injector.bindVolatileParameter(RemoteConfiguration.HostAddress.class, host);
+      // Port 0 leaves the choice of port to the transport; the chosen one is reported by getListeningPort().
+      injector.bindVolatileParameter(RemoteConfiguration.Port.class, 0);
+      injector.bindVolatileParameter(RemoteConfiguration.RemoteServerStage.class,
+          new SyncStage<>(new ServerEventHandler()));
+      transport = injector.getInstance(NettyMessagingTransport.class);
+      LOG.info("DriverRPCServer running at {}", transport.getListeningPort());
+      isRunning = true;
+    } catch (final InjectionException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  /**
+   * @return the port the transport is listening on
+   * @throws IllegalStateException if the server is not running or already shut down
+   */
+  public int getListeningPort() {
+    ensureServerState(true);
+    return transport.getListeningPort();
+  }
+
+  /**
+   * @return the host of the client, as determined by {@link LocalAddressProvider} in {@link #run()}
+   * @throws IllegalStateException if the server is not running or already shut down
+   */
+  public String getListeningHost() {
+    ensureServerState(true);
+    return host;
+  }
+
+  /**
+   * @return a configuration binding {@link JobConf.ClientSideRPCServerHost} and
+   *         {@link JobConf.ClientSideRPCServerPort} to this server's listening address
+   * @throws IllegalStateException if the server is not running or already shut down
+   */
+  public Configuration getListeningConfiguration() {
+    return Tang.Factory.getTang().newConfigurationBuilder()
+        .bindNamedParameter(JobConf.ClientSideRPCServerHost.class, getListeningHost())
+        .bindNamedParameter(JobConf.ClientSideRPCServerPort.class, String.valueOf(getListeningPort()))
+        .build();
+  }
+
+  /**
+   * Sends a message to driver.
+   * The link to the driver is captured when its DriverStarted message arrives, so this works only after that.
+   *
+   * @param message message to send
+   * @throws IllegalStateException if the server is not running, is shut down, or has not heard from the driver
+   */
+  public void send(final ControlMessage.ClientToDriverMessage message) {
+    ensureServerState(true);
+    // Read the volatile field once so the null check and the write use the same link.
+    final Link<byte[]> driverLink = link;
+    if (driverLink == null) {
+      throw new IllegalStateException("The RPC server has not discovered NemoDriver yet");
+    }
+    driverLink.write(message.toByteArray());
+  }
+
+  /**
+   * Shut down the server by closing its transport.
+   * The server is marked shut down even if closing fails, so it cannot be reused afterwards.
+   *
+   * @throws IllegalStateException if the server is not running or already shut down
+   * @throws RuntimeException      wrapping any exception thrown while closing the transport
+   */
+  public void shutdown() {
+    ensureServerState(true);
+    try {
+      transport.close();
+    } catch (final Exception e) {
+      throw new RuntimeException(e);
+    } finally {
+      isShutdown = true;
+    }
+  }
+
+  /**
+   * Handles messages from driver: decodes each payload and dispatches it to the handler for its type.
+   * A DriverStarted message also records the link that {@link #send} uses.
+   */
+  private final class ServerEventHandler implements EventHandler<TransportEvent> {
+    @Override
+    public void onNext(final TransportEvent transportEvent) {
+      final ControlMessage.DriverToClientMessage message;
+      try {
+        message = ControlMessage.DriverToClientMessage.parseFrom(transportEvent.getData());
+      } catch (final InvalidProtocolBufferException e) {
+        throw new RuntimeException(e);
+      }
+
+      final ControlMessage.DriverToClientMessageType type = message.getType();
+      if (type == ControlMessage.DriverToClientMessageType.DriverStarted) {
+        // Capture the link before dispatching, so a DriverStarted handler can already call send().
+        link = transportEvent.getLink();
+      }
+
+      final EventHandler<ControlMessage.DriverToClientMessage> handler = handlers.get(type);
+      if (handler == null) {
+        throw new IllegalStateException(String.format("Handler for message type %s not registered", type));
+      }
+      handler.onNext(message);
+    }
+  }
+
+  /**
+   * Checks the server's lifecycle state.
+   *
+   * @param running the expected state of the server
+   * @throws IllegalStateException if the server is shut down, or its running state differs from {@code running}
+   */
+  private void ensureServerState(final boolean running) {
+    if (isShutdown) {
+      throw new IllegalStateException("The DriverRPCServer is already shutdown");
+    }
+    if (running != isRunning) {
+      throw new IllegalStateException(
+          String.format("The DriverRPCServer is %s running", isRunning ? "already" : "not"));
+    }
+  }
+}
diff --git a/client/src/main/java/edu/snu/nemo/client/JobLauncher.java b/client/src/main/java/edu/snu/nemo/client/JobLauncher.java
index 0b7ca67..5c330d7 100644
--- a/client/src/main/java/edu/snu/nemo/client/JobLauncher.java
+++ b/client/src/main/java/edu/snu/nemo/client/JobLauncher.java
@@ -19,6 +19,7 @@ import com.google.common.annotations.VisibleForTesting;
import edu.snu.nemo.common.dag.DAG;
import edu.snu.nemo.conf.JobConf;
import edu.snu.nemo.driver.NemoDriver;
+import edu.snu.nemo.runtime.common.comm.ControlMessage;
import edu.snu.nemo.runtime.common.message.MessageEnvironment;
import edu.snu.nemo.runtime.common.message.MessageParameters;
import org.apache.commons.lang3.SerializationUtils;
@@ -58,6 +59,7 @@ public final class JobLauncher {
private static Configuration jobAndDriverConf = null;
private static Configuration deployModeConf = null;
private static Configuration builtJobConf = null;
+  /**
+   * Base64-encoded, Java-serialized DAG to send to the driver once it reports ResourceReady.
+   * It is written when a DAG is launched and read by the ResourceReady handler. That handler is invoked from
+   * the RPC transport rather than from the launching thread, so the field is {@code volatile} to make the
+   * write visible there.
+   */
+  private static volatile String serializedDAG;
/**
* private constructor.
@@ -72,6 +74,16 @@ public final class JobLauncher {
* @throws Exception exception on the way.
*/
public static void main(final String[] args) throws Exception {
+ final DriverRPCServer driverRPCServer = new DriverRPCServer();
+ // Registers actions for launching the DAG.
+ driverRPCServer
+ .registerHandler(ControlMessage.DriverToClientMessageType.DriverStarted, event -> { })
+ .registerHandler(ControlMessage.DriverToClientMessageType.ResourceReady, event ->
+ driverRPCServer.send(ControlMessage.ClientToDriverMessage.newBuilder()
+ .setType(ControlMessage.ClientToDriverMessageType.LaunchDAG)
+ .setLaunchDAG(ControlMessage.LaunchDAGMessage.newBuilder().setDag(serializedDAG).build()).build()))
+ .run();
+
// Get Job and Driver Confs
builtJobConf = getJobConf(args);
final Configuration driverConf = getDriverConf(builtJobConf);
@@ -82,13 +94,15 @@ public final class JobLauncher {
// Merge Job and Driver Confs
jobAndDriverConf = Configurations.merge(builtJobConf, driverConf, driverNcsConf, driverMessageConfg,
- executorResourceConfig);
+ executorResourceConfig, driverRPCServer.getListeningConfiguration());
// Get DeployMode Conf
deployModeConf = Configurations.merge(getDeployModeConf(builtJobConf), clientConf);
// Launch client main
runUserProgramMain(builtJobConf);
+
+ driverRPCServer.shutdown();
}
/**
@@ -102,13 +116,10 @@ public final class JobLauncher {
if (jobAndDriverConf == null || deployModeConf == null || builtJobConf == null) {
throw new RuntimeException("Configuration for launching driver is not ready");
}
- final String serializedDAG = Base64.getEncoder().encodeToString(SerializationUtils.serialize(dag));
- final Configuration dagConf = TANG.newConfigurationBuilder()
- .bindNamedParameter(JobConf.SerializedDAG.class, serializedDAG)
- .build();
+ serializedDAG = Base64.getEncoder().encodeToString(SerializationUtils.serialize(dag));
// Launch and wait indefinitely for the job to finish
final LauncherStatus launcherStatus = DriverLauncher.getLauncher(deployModeConf)
- .run(Configurations.merge(jobAndDriverConf, dagConf));
+ .run(jobAndDriverConf);
final Optional<Throwable> possibleError = launcherStatus.getError();
if (possibleError.isPresent()) {
throw new RuntimeException(possibleError.get());
diff --git a/conf/src/main/java/edu/snu/nemo/conf/JobConf.java b/conf/src/main/java/edu/snu/nemo/conf/JobConf.java
index da3d671..1e4ef4d 100644
--- a/conf/src/main/java/edu/snu/nemo/conf/JobConf.java
+++ b/conf/src/main/java/edu/snu/nemo/conf/JobConf.java
@@ -73,6 +73,22 @@ public final class JobConf extends ConfigurationModuleBuilder {
public final class GlusterVolumeDirectory implements Name<String> {
}
+ //////////////////////////////// Client-Driver RPC
+
+ /**
+ * Host of the client-side RPC server ({@code DriverRPCServer}) that the driver connects back to.
+ * The client fills this in from its {@code LocalAddressProvider} when the RPC server starts.
+ */
+ @NamedParameter
+ public final class ClientSideRPCServerHost implements Name<String> {
+ }
+
+ /**
+ * Port of the client-side RPC server ({@code DriverRPCServer}) that the driver connects back to.
+ * The server is started with port 0, so the actual port is only known after it starts listening.
+ * The actual port is then passed to the driver through this parameter.
+ */
+ @NamedParameter
+ public final class ClientSideRPCServerPort implements Name<Integer> {
+ }
+
//////////////////////////////// Compiler Configurations
/**
@@ -227,13 +243,6 @@ public final class JobConf extends ConfigurationModuleBuilder {
public final class ExecutorId implements Name<String> {
}
- /**
- * Serialized {edu.snu.nemo.common.dag.DAG} from user main method.
- */
- @NamedParameter(doc = "String serialized DAG")
- public final class SerializedDAG implements Name<String> {
- }
-
public static final RequiredParameter<String> EXECUTOR_ID = new RequiredParameter<>();
public static final RequiredParameter<String> JOB_ID = new RequiredParameter<>();
public static final OptionalParameter<String> LOCAL_DISK_DIRECTORY = new OptionalParameter<>();
diff --git a/runtime/common/src/main/proto/ControlMessage.proto b/runtime/common/src/main/proto/ControlMessage.proto
index 664734b..f6bd527 100644
--- a/runtime/common/src/main/proto/ControlMessage.proto
+++ b/runtime/common/src/main/proto/ControlMessage.proto
@@ -19,6 +19,28 @@ package protobuf;
option java_package = "edu.snu.nemo.runtime.common.comm";
option java_outer_classname = "ControlMessage";
+// Type of a message sent from the Nemo client to the driver.
+enum ClientToDriverMessageType {
+ // Carries a serialized DAG that the driver should compile and run.
+ LaunchDAG = 0;
+}
+
+// Envelope for client-to-driver messages; which optional payload is set depends on `type`.
+message ClientToDriverMessage {
+ required ClientToDriverMessageType type = 1;
+ // Set when type is LaunchDAG.
+ optional LaunchDAGMessage launchDAG = 2;
+}
+
+message LaunchDAGMessage {
+ // Base64-encoded, Java-serialized IR DAG produced by the client.
+ required string dag = 1;
+}
+
+// Type of a message sent from the driver to the Nemo client.
+enum DriverToClientMessageType {
+ // Sent when the driver is constructed; the client also captures its link back to the driver from it.
+ DriverStarted = 0;
+ // Sent once the final executor has launched; the client answers with LaunchDAG.
+ ResourceReady = 1;
+}
+
+// Driver-to-client message; currently carries only its type.
+message DriverToClientMessage {
+ required DriverToClientMessageType type = 1;
+}
+
enum MessageType {
TaskStateChanged = 0;
ScheduleTask = 1;
diff --git a/runtime/driver/src/main/java/edu/snu/nemo/driver/ClientRPC.java b/runtime/driver/src/main/java/edu/snu/nemo/driver/ClientRPC.java
new file mode 100644
index 0000000..82698f0
--- /dev/null
+++ b/runtime/driver/src/main/java/edu/snu/nemo/driver/ClientRPC.java
@@ -0,0 +1,167 @@
+/*
+ * Copyright (C) 2018 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.driver;
+
+import com.google.protobuf.InvalidProtocolBufferException;
+import edu.snu.nemo.conf.JobConf;
+import edu.snu.nemo.runtime.common.comm.ControlMessage;
+import org.apache.reef.tang.annotations.Parameter;
+import org.apache.reef.wake.EventHandler;
+import org.apache.reef.wake.impl.SyncStage;
+import org.apache.reef.wake.remote.Encoder;
+import org.apache.reef.wake.remote.address.LocalAddressProvider;
+import org.apache.reef.wake.remote.impl.TransportEvent;
+import org.apache.reef.wake.remote.transport.Link;
+import org.apache.reef.wake.remote.transport.LinkListener;
+import org.apache.reef.wake.remote.transport.Transport;
+import org.apache.reef.wake.remote.transport.TransportFactory;
+
+import javax.inject.Inject;
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.net.SocketAddress;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+/**
+ * Driver-side RPC implementation for communication from/to Nemo Client.
+ *
+ * <p>On construction it creates its own transport and opens a link to the client-side RPC server.
+ * The server's address comes from {@link JobConf.ClientSideRPCServerHost} and
+ * {@link JobConf.ClientSideRPCServerPort}.</p>
+ */
+public final class ClientRPC {
+  private static final DriverToClientMessageEncoder ENCODER = new DriverToClientMessageEncoder();
+  private static final ClientRPCLinkListener LINK_LISTENER = new ClientRPCLinkListener();
+  // Retry settings forwarded to TransportFactory#newInstance (RETRY_TIMEOUT presumably in ms — confirm in REEF).
+  private static final int RETRY_COUNT = 10;
+  private static final int RETRY_TIMEOUT = 100;
+
+  private final Map<ControlMessage.ClientToDriverMessageType, EventHandler<ControlMessage.ClientToDriverMessage>>
+      handlers = new ConcurrentHashMap<>();
+  private final Transport transport;
+  private final Link<ControlMessage.DriverToClientMessage> link;
+  private volatile boolean isClosed = false;
+
+  /**
+   * Creates the transport and opens a link to the client-side RPC server.
+   *
+   * @param transportFactory     factory for the underlying transport
+   * @param localAddressProvider provides the local address the transport binds to
+   * @param clientHost           host of the client-side RPC server
+   * @param clientPort           port of the client-side RPC server
+   * @throws IOException if the link to the client cannot be opened; the transport is closed in that case
+   */
+  @Inject
+  private ClientRPC(final TransportFactory transportFactory,
+                    final LocalAddressProvider localAddressProvider,
+                    @Parameter(JobConf.ClientSideRPCServerHost.class) final String clientHost,
+                    @Parameter(JobConf.ClientSideRPCServerPort.class) final int clientPort) throws IOException {
+    transport = transportFactory.newInstance(localAddressProvider.getLocalAddress(),
+        0, new SyncStage<>(new RPCEventHandler()), null, RETRY_COUNT, RETRY_TIMEOUT);
+    final SocketAddress clientAddress = new InetSocketAddress(clientHost, clientPort);
+    try {
+      link = transport.open(clientAddress, ENCODER, LINK_LISTENER);
+    } catch (final IOException | RuntimeException e) {
+      // A failed constructor yields no object on which shutdown() could be called, so release the
+      // already-created transport here instead of leaking it.
+      try {
+        transport.close();
+      } catch (final Exception closeFailure) {
+        e.addSuppressed(closeFailure);
+      }
+      throw e;
+    }
+  }
+
+  /**
+   * Registers handler for the given type of message.
+   * A message whose type has no registered handler is rejected when it arrives.
+   *
+   * @param type    the type of message
+   * @param handler handler implementation
+   * @return {@code this}, for chaining
+   * @throws IllegalStateException if a handler for {@code type} is already registered
+   */
+  public ClientRPC registerHandler(final ControlMessage.ClientToDriverMessageType type,
+                                   final EventHandler<ControlMessage.ClientToDriverMessage> handler) {
+    if (handlers.putIfAbsent(type, handler) != null) {
+      throw new IllegalStateException(String.format("A handler for %s already registered", type));
+    }
+    return this;
+  }
+
+  /**
+   * Shuts down the transport.
+   * This object is marked closed even if closing fails, so later {@link #send} or {@code shutdown} calls fail fast.
+   *
+   * @throws IllegalStateException if already closed
+   * @throws RuntimeException      wrapping any exception thrown while closing the transport
+   */
+  public void shutdown() {
+    ensureRunning();
+    try {
+      transport.close();
+    } catch (final Exception e) {
+      throw new RuntimeException(e);
+    } finally {
+      isClosed = true;
+    }
+  }
+
+  /**
+   * Write message to client.
+   *
+   * @param message message to send.
+   * @throws IllegalStateException if this RPC has already been shut down
+   */
+  public void send(final ControlMessage.DriverToClientMessage message) {
+    ensureRunning();
+    link.write(message);
+  }
+
+  /**
+   * Dispatches a message from the client to the handler registered for its type.
+   *
+   * @param message message to process
+   * @throws IllegalStateException if no handler is registered for the message's type
+   */
+  private void handleMessage(final ControlMessage.ClientToDriverMessage message) {
+    final ControlMessage.ClientToDriverMessageType type = message.getType();
+    final EventHandler<ControlMessage.ClientToDriverMessage> handler = handlers.get(type);
+    if (handler == null) {
+      throw new IllegalStateException(String.format("Handler for message type %s not registered", type));
+    }
+    handler.onNext(message);
+  }
+
+  /**
+   * Decodes transport events into {@link ControlMessage.ClientToDriverMessage}s and dispatches them.
+   */
+  private final class RPCEventHandler implements EventHandler<TransportEvent> {
+    @Override
+    public void onNext(final TransportEvent transportEvent) {
+      final ControlMessage.ClientToDriverMessage message;
+      try {
+        message = ControlMessage.ClientToDriverMessage.parseFrom(transportEvent.getData());
+      } catch (final InvalidProtocolBufferException e) {
+        throw new RuntimeException(e);
+      }
+      handleMessage(message);
+    }
+  }
+
+  /**
+   * Ensure the Transport is running.
+   *
+   * @throws IllegalStateException if {@link #shutdown()} has already been called
+   */
+  private void ensureRunning() {
+    if (isClosed) {
+      throw new IllegalStateException("The ClientRPC is already closed");
+    }
+  }
+
+  /**
+   * Encodes a {@link ControlMessage.DriverToClientMessage} as its protobuf byte representation.
+   */
+  private static final class DriverToClientMessageEncoder implements Encoder<ControlMessage.DriverToClientMessage> {
+    @Override
+    public byte[] encode(final ControlMessage.DriverToClientMessage driverToClientMessage) {
+      return driverToClientMessage.toByteArray();
+    }
+  }
+
+  /**
+   * {@link LinkListener} for the link to the client: ignores successful writes and rethrows write failures.
+   */
+  private static final class ClientRPCLinkListener implements LinkListener<ControlMessage.DriverToClientMessage> {
+    @Override
+    public void onSuccess(final ControlMessage.DriverToClientMessage driverToClientMessage) {
+      // Nothing to do on a successful write.
+    }
+
+    @Override
+    public void onException(final Throwable throwable,
+                            final SocketAddress socketAddress,
+                            final ControlMessage.DriverToClientMessage driverToClientMessage) {
+      throw new RuntimeException(throwable);
+    }
+  }
+}
diff --git a/runtime/driver/src/main/java/edu/snu/nemo/driver/NemoDriver.java b/runtime/driver/src/main/java/edu/snu/nemo/driver/NemoDriver.java
index e2e263c..8102d64 100644
--- a/runtime/driver/src/main/java/edu/snu/nemo/driver/NemoDriver.java
+++ b/runtime/driver/src/main/java/edu/snu/nemo/driver/NemoDriver.java
@@ -18,6 +18,7 @@ package edu.snu.nemo.driver;
import edu.snu.nemo.common.ir.IdManager;
import edu.snu.nemo.conf.JobConf;
import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
+import edu.snu.nemo.runtime.common.comm.ControlMessage;
import edu.snu.nemo.runtime.common.message.MessageParameters;
import edu.snu.nemo.runtime.master.RuntimeMaster;
import org.apache.reef.annotations.audience.DriverSide;
@@ -67,9 +68,9 @@ public final class NemoDriver {
private final String jobId;
private final String localDirectory;
private final String glusterDirectory;
+ private final ClientRPC clientRPC;
// Client for sending log messages
- private final JobMessageObserver client;
private final RemoteClientMessageLoggingHandler handler;
@Inject
@@ -78,6 +79,7 @@ public final class NemoDriver {
final NameServer nameServer,
final LocalAddressProvider localAddressProvider,
final JobMessageObserver client,
+ final ClientRPC clientRPC,
@Parameter(JobConf.ExecutorJsonContents.class) final String resourceSpecificationString,
@Parameter(JobConf.JobId.class) final String jobId,
@Parameter(JobConf.FileDirectory.class) final String localDirectory,
@@ -91,8 +93,13 @@ public final class NemoDriver {
this.jobId = jobId;
this.localDirectory = localDirectory;
this.glusterDirectory = glusterDirectory;
- this.client = client;
this.handler = new RemoteClientMessageLoggingHandler(client);
+ this.clientRPC = clientRPC;
+ clientRPC.registerHandler(ControlMessage.ClientToDriverMessageType.LaunchDAG,
+ message -> startSchedulingUserApplication(message.getLaunchDAG().getDag()));
+ // Send DriverStarted message to the client
+ clientRPC.send(ControlMessage.DriverToClientMessage.newBuilder()
+ .setType(ControlMessage.DriverToClientMessageType.DriverStarted).build());
}
/**
@@ -135,15 +142,19 @@ public final class NemoDriver {
final boolean finalExecutorLaunched = runtimeMaster.onExecutorLaunched(activeContext);
if (finalExecutorLaunched) {
- startSchedulingUserApplication();
+ clientRPC.send(ControlMessage.DriverToClientMessage.newBuilder()
+ .setType(ControlMessage.DriverToClientMessageType.ResourceReady).build());
}
}
}
- private void startSchedulingUserApplication() {
+ /**
+ * Start user application.
+ * Runs {@link UserApplicationRunner#run(String)} on a new single-thread executor so the caller is not blocked
+ * (e.g. the LaunchDAG RPC handler registered in the constructor). The executor is shut down right after
+ * submission, so it accepts no further tasks and its thread ends once the application run returns.
+ *
+ * @param dagString serialized IR DAG received from the client (Base64-encoded by the client)
+ */
+ public void startSchedulingUserApplication(final String dagString) {
// Launch user application (with a new thread)
final ExecutorService userApplicationRunnerThread = Executors.newSingleThreadExecutor();
- userApplicationRunnerThread.execute(userApplicationRunner);
+ userApplicationRunnerThread.execute(() -> userApplicationRunner.run(dagString));
userApplicationRunnerThread.shutdown();
}
@@ -175,6 +186,7 @@ public final class NemoDriver {
@Override
public void onNext(final StopTime stopTime) {
handler.close();
+ // Also close the RPC transport to the client; ClientRPC rejects further sends after this.
+ clientRPC.shutdown();
}
}
diff --git a/runtime/driver/src/main/java/edu/snu/nemo/driver/UserApplicationRunner.java b/runtime/driver/src/main/java/edu/snu/nemo/driver/UserApplicationRunner.java
index 3e415d8..6ba615f 100644
--- a/runtime/driver/src/main/java/edu/snu/nemo/driver/UserApplicationRunner.java
+++ b/runtime/driver/src/main/java/edu/snu/nemo/driver/UserApplicationRunner.java
@@ -42,11 +42,10 @@ import java.util.concurrent.ScheduledExecutorService;
/**
* Compiles and runs User application.
*/
-public final class UserApplicationRunner implements Runnable {
+public final class UserApplicationRunner {
private static final Logger LOG = LoggerFactory.getLogger(UserApplicationRunner.class.getName());
private final String dagDirectory;
- private final String dagString;
private final String optimizationPolicyCanonicalName;
private final int maxScheduleAttempt;
@@ -58,14 +57,12 @@ public final class UserApplicationRunner implements Runnable {
@Inject
private UserApplicationRunner(@Parameter(JobConf.DAGDirectory.class) final String dagDirectory,
- @Parameter(JobConf.SerializedDAG.class) final String dagString,
@Parameter(JobConf.OptimizationPolicy.class) final String optimizationPolicy,
@Parameter(JobConf.MaxScheduleAttempt.class) final int maxScheduleAttempt,
final PubSubEventHandlerWrapper pubSubEventHandlerWrapper,
final Injector injector,
final RuntimeMaster runtimeMaster) {
this.dagDirectory = dagDirectory;
- this.dagString = dagString;
this.optimizationPolicyCanonicalName = optimizationPolicy;
this.maxScheduleAttempt = maxScheduleAttempt;
this.injector = injector;
@@ -74,8 +71,14 @@ public final class UserApplicationRunner implements Runnable {
this.pubSubWrapper = pubSubEventHandlerWrapper;
}
- @Override
- public void run() {
+ /**
+ * Run the user program submitted by Nemo Client.
+ * Specifically, deserialize DAG from Client, optimize it, generate physical plan,
+ * and tell {@link RuntimeMaster} to execute the plan.
+ *
+ * @param dagString Serialized IR DAG from Nemo Client.
+ */
+ public void run(final String dagString) {
try {
LOG.info("##### Nemo Compiler #####");
diff --git a/tests/src/test/java/edu/snu/nemo/client/ClientDriverRPCTest.java b/tests/src/test/java/edu/snu/nemo/client/ClientDriverRPCTest.java
new file mode 100644
index 0000000..af94b59
--- /dev/null
+++ b/tests/src/test/java/edu/snu/nemo/client/ClientDriverRPCTest.java
@@ -0,0 +1,97 @@
+/*
+ * Copyright (C) 2018 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.client;
+
+import edu.snu.nemo.driver.ClientRPC;
+import edu.snu.nemo.runtime.common.comm.ControlMessage;
+import org.apache.reef.tang.Injector;
+import org.apache.reef.tang.Tang;
+import org.apache.reef.tang.exceptions.InjectionException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Test for communication between {@link DriverRPCServer} and {@link ClientRPC}.
+ */
+public final class ClientDriverRPCTest {
+  /**
+   * How long a test waits for an expected handler invocation.
+   * Bounding the wait turns a lost message into a test failure instead of a hung build.
+   */
+  private static final long HANDLER_TIMEOUT_SECONDS = 10;
+
+  private DriverRPCServer driverRPCServer;
+  private ClientRPC clientRPC;
+
+  @Before
+  public void setupDriverRPCServer() {
+    // A fresh, not-yet-running server, so each test can register handlers before run().
+    driverRPCServer = new DriverRPCServer();
+  }
+
+  /**
+   * Starts the client-side server and injects a driver-side {@link ClientRPC} connected to it.
+   *
+   * @throws InjectionException if {@link ClientRPC} cannot be created
+   */
+  private void setupClientRPC() throws InjectionException {
+    driverRPCServer.run();
+    final Injector clientRPCInjector =
+        Tang.Factory.getTang().newInjector(driverRPCServer.getListeningConfiguration());
+    clientRPC = clientRPCInjector.getInstance(ClientRPC.class);
+  }
+
+  @After
+  public void cleanup() {
+    try {
+      driverRPCServer.shutdown();
+    } finally {
+      // clientRPC is still null if setupClientRPC() failed; skip it rather than masking that failure with an NPE.
+      if (clientRPC != null) {
+        clientRPC.shutdown();
+      }
+    }
+  }
+
+  /**
+   * Test with empty set of handlers.
+   * @throws InjectionException on Exceptions on creating {@link ClientRPC}.
+   */
+  @Test
+  public void testRPCSetup() throws InjectionException {
+    setupClientRPC();
+  }
+
+  /**
+   * Test with basic request method from driver to client.
+   * @throws InjectionException on Exceptions on creating {@link ClientRPC}.
+   * @throws InterruptedException when interrupted while waiting EventHandler invocation
+   */
+  @Test
+  public void testDriverToClientMethodInvocation() throws InjectionException, InterruptedException {
+    final CountDownLatch latch = new CountDownLatch(1);
+    driverRPCServer.registerHandler(ControlMessage.DriverToClientMessageType.DriverStarted,
+        msg -> latch.countDown());
+    setupClientRPC();
+    clientRPC.send(ControlMessage.DriverToClientMessage.newBuilder()
+        .setType(ControlMessage.DriverToClientMessageType.DriverStarted).build());
+    assertTrue("DriverStarted handler was not invoked in time",
+        latch.await(HANDLER_TIMEOUT_SECONDS, TimeUnit.SECONDS));
+  }
+
+  /**
+   * Test with request-response RPC between client and driver.
+   * @throws InjectionException on Exceptions on creating {@link ClientRPC}.
+   * @throws InterruptedException when interrupted while waiting EventHandler invocation
+   */
+  @Test
+  public void testBetweenClientAndDriver() throws InjectionException, InterruptedException {
+    final CountDownLatch latch = new CountDownLatch(1);
+    // The client answers DriverStarted with LaunchDAG; the driver side counts down on receiving it.
+    driverRPCServer.registerHandler(ControlMessage.DriverToClientMessageType.DriverStarted,
+        msg -> driverRPCServer.send(ControlMessage.ClientToDriverMessage.newBuilder()
+            .setType(ControlMessage.ClientToDriverMessageType.LaunchDAG)
+            .setLaunchDAG(ControlMessage.LaunchDAGMessage.newBuilder().setDag("").build())
+            .build()));
+    setupClientRPC();
+    clientRPC.registerHandler(ControlMessage.ClientToDriverMessageType.LaunchDAG, msg -> latch.countDown());
+    clientRPC.send(ControlMessage.DriverToClientMessage.newBuilder()
+        .setType(ControlMessage.DriverToClientMessageType.DriverStarted).build());
+    assertTrue("LaunchDAG handler was not invoked in time",
+        latch.await(HANDLER_TIMEOUT_SECONDS, TimeUnit.SECONDS));
+  }
+}