You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@nemo.apache.org by wy...@apache.org on 2019/03/30 15:17:05 UTC
[incubator-nemo] branch master updated: [NEMO-360] Implementing an
'XGBoostPolicy' (#203)
This is an automated email from the ASF dual-hosted git repository.
wylee pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-nemo.git
The following commit(s) were added to refs/heads/master by this push:
new fbd2e6e [NEMO-360] Implementing an 'XGBoostPolicy' (#203)
fbd2e6e is described below
commit fbd2e6e16fea969e0e452f3d66e1691b08e62454
Author: Won Wook SONG <wo...@apache.org>
AuthorDate: Sun Mar 31 00:17:01 2019 +0900
[NEMO-360] Implementing an 'XGBoostPolicy' (#203)
JIRA: [NEMO-360: Implementing an 'XGBoostPolicy'](https://issues.apache.org/jira/projects/NEMO/issues/NEMO-360)
**Major changes:**
- Adds a Python script that runs XGBoost.
- Adds a bash script for invoking the Python script.
- Uses the Client <-> Driver RPC to run the XGBoost script and return the results to the driver.
- Enables environment tag.
- Refactors the utility classes accordingly.
**Minor changes to note:**
- Miscellaneous methods to help the implementation (e.g., getEdgeById)
- Consistency in the shell scripts
- Appropriately adds files to the gitignore
- Javadocs
**Tests for the changes:**
- MetricUtilsTest confirms that changing an EP to indices and back to EP works correctly (TDD)
**Other comments:**
- None
Closes #203
---
.gitignore | 7 +
bin/run_beam.sh | 2 +-
bin/run_nexmark.sh | 2 +-
bin/run_spark.sh | 6 +-
bin/{run_spark.sh => xgboost_optimization.sh} | 10 +-
client/pom.xml | 6 +-
.../java/org/apache/nemo/client/ClientUtils.java | 95 +++++
.../java/org/apache/nemo/client/JobLauncher.java | 5 +
.../apache/nemo/client/ClientDriverRPCTest.java | 2 +-
common/pom.xml | 1 -
.../src/main/java/org/apache/nemo/common/Util.java | 29 +-
.../main/java/org/apache/nemo/common/dag/DAG.java | 8 +
.../org/apache/nemo/common/dag/DAGInterface.java | 12 +-
.../exception/IllegalEdgeOperationException.java | 9 +
.../main/java/org/apache/nemo/common/ir/IRDAG.java | 36 +-
.../executionproperty/PartitionerProperty.java | 11 +
.../org/apache/nemo/common/ir/vertex/IRVertex.java | 4 +
.../IgnoreSchedulingTempDataReceiverProperty.java | 17 +-
.../ResourceLocalityProperty.java | 4 +-
.../executionproperty/ResourceSlotProperty.java | 4 +-
.../apache/nemo/common/{util => }/UtilTest.java | 13 +-
compiler/frontend/beam/pom.xml | 2 -
compiler/optimizer/pom.xml | 5 +
.../nemo/compiler/optimizer/NemoOptimizer.java | 39 ++-
.../nemo/compiler/optimizer/OptimizerUtils.java | 86 +++++
.../pass/compiletime/annotating/XGBoostPass.java | 136 ++++++++
.../compiler/optimizer/policy/XGBoostPolicy.java | 54 +++
.../main/java/org/apache/nemo/conf/JobConf.java | 9 +-
ml/nemo_xgboost_optimization.py | 385 +++++++++++++++++++++
ml/requirements.txt | 9 +
pom.xml | 4 +
runtime/common/pom.xml | 7 +
.../nemo/runtime/common/message}/ClientRPC.java | 2 +-
.../nemo/runtime/common/metric/MetricUtils.java | 337 +++++++++++++-----
runtime/common/src/main/proto/ControlMessage.proto | 20 +-
.../runtime/common/metric/MetricUtilsTest.java | 155 +++++++++
.../java/org/apache/nemo/driver/NemoDriver.java | 19 +-
.../apache/nemo/driver/UserApplicationRunner.java | 9 +
.../executor/datatransfer/DataTransferTest.java | 2 +-
.../apache/nemo/runtime/master/RuntimeMaster.java | 27 +-
.../nemo/runtime/master/metric/MetricStore.java | 46 ++-
41 files changed, 1470 insertions(+), 166 deletions(-)
diff --git a/.gitignore b/.gitignore
index 5497604..4d3db04 100644
--- a/.gitignore
+++ b/.gitignore
@@ -56,6 +56,13 @@ MetricsData
.temp*
#
# ----------------------------------------------------------------------
+# ML Files
+# ----------------------------------------------------------------------
+ml/*.out
+ml/*.model
+venv/*
+#
+# ----------------------------------------------------------------------
# Unknown Files. Please clean up over time
# ----------------------------------------------------------------------
ml-data
diff --git a/bin/run_beam.sh b/bin/run_beam.sh
index 59b5a67..ddd0c75 100755
--- a/bin/run_beam.sh
+++ b/bin/run_beam.sh
@@ -21,4 +21,4 @@ VERSION=$(mvn -q \
-Dexec.executable=echo -Dexec.args='${project.version}' \
--non-recursive exec:exec)
-java -Dlog4j.configuration=file://`pwd`/log4j.properties -cp client/target/nemo-client-$VERSION-shaded.jar:examples/beam/target/nemo-examples-beam-$VERSION-shaded.jar:`yarn classpath` org.apache.nemo.client.JobLauncher "$@"
+java -Dlog4j.configuration=file://`pwd`/log4j.properties -cp examples/beam/target/nemo-examples-beam-${VERSION}-shaded.jar:client/target/nemo-client-${VERSION}-shaded.jar:`yarn classpath` org.apache.nemo.client.JobLauncher "$@"
diff --git a/bin/run_nexmark.sh b/bin/run_nexmark.sh
index d1b6a2c..ce2c0ef 100755
--- a/bin/run_nexmark.sh
+++ b/bin/run_nexmark.sh
@@ -21,4 +21,4 @@ VERSION=$(mvn -q \
-Dexec.executable=echo -Dexec.args='${project.version}' \
--non-recursive exec:exec)
-java -Dlog4j.configuration=file://`pwd`/log4j.properties -cp client/target/nemo-client-$VERSION-shaded.jar:`yarn classpath`:examples/nexmark/target/nexmark-$VERSION-shaded.jar org.apache.nemo.client.JobLauncher "$@"
+java -Dlog4j.configuration=file://`pwd`/log4j.properties -cp examples/nexmark/target/nexmark-${VERSION}-shaded.jar:client/target/nemo-client-${VERSION}-shaded.jar:`yarn classpath` org.apache.nemo.client.JobLauncher "$@"
diff --git a/bin/run_spark.sh b/bin/run_spark.sh
index 314fd0d..b4a318d 100755
--- a/bin/run_spark.sh
+++ b/bin/run_spark.sh
@@ -17,6 +17,8 @@
# specific language governing permissions and limitations
# under the License.
-java -Dlog4j.configuration=file://`pwd`/log4j.properties -cp examples/spark/target/nemo-examples-spark-$(mvn -q \
+VERSION=$(mvn -q \
-Dexec.executable=echo -Dexec.args='${project.version}' \
- --non-recursive exec:exec)-shaded.jar:`yarn classpath` org.apache.nemo.client.JobLauncher "$@"
+ --non-recursive exec:exec)
+
+java -Dlog4j.configuration=file://`pwd`/log4j.properties -cp examples/spark/target/nemo-examples-spark-${VERSION}-shaded.jar:`yarn classpath` org.apache.nemo.client.JobLauncher "$@"
diff --git a/bin/run_spark.sh b/bin/xgboost_optimization.sh
similarity index 74%
copy from bin/run_spark.sh
copy to bin/xgboost_optimization.sh
index 314fd0d..ed141f8 100755
--- a/bin/run_spark.sh
+++ b/bin/xgboost_optimization.sh
@@ -17,6 +17,10 @@
# specific language governing permissions and limitations
# under the License.
-java -Dlog4j.configuration=file://`pwd`/log4j.properties -cp examples/spark/target/nemo-examples-spark-$(mvn -q \
- -Dexec.executable=echo -Dexec.args='${project.version}' \
- --non-recursive exec:exec)-shaded.jar:`yarn classpath` org.apache.nemo.client.JobLauncher "$@"
+echo "You should already have python3 installed"
+echo "Usage: ./bin/nemo_xgboost_optimization.sh <tablename>"
+pushd ml
+touch results.out
+pip3 install -r requirements.txt
+python3 nemo_xgboost_optimization.py -t "$@"
+popd
diff --git a/client/pom.xml b/client/pom.xml
index 367171b..a2eba1f 100644
--- a/client/pom.xml
+++ b/client/pom.xml
@@ -57,8 +57,6 @@ under the License.
<artifactId>reef-runtime-yarn</artifactId>
<version>${reef.version}</version>
</dependency>
-
-
<!-- for nemo-beam-runner -->
<dependency>
<groupId>org.apache.nemo</groupId>
@@ -73,7 +71,6 @@ under the License.
</dependency>
</dependencies>
-
<build>
<plugins>
<plugin>
@@ -99,8 +96,7 @@ under the License.
${project.build.directory}/${project.artifactId}-${project.version}-shaded.jar
</outputFile>
<transformers>
- <!-- Required for using beam-hadoop: See https://stackoverflow.com/questions/44365545
- -->
+ <!-- Required for using beam-hadoop: See https://stackoverflow.com/questions/44365545-->
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer"/>
</transformers>
</configuration>
diff --git a/client/src/main/java/org/apache/nemo/client/ClientUtils.java b/client/src/main/java/org/apache/nemo/client/ClientUtils.java
new file mode 100644
index 0000000..a200a2e
--- /dev/null
+++ b/client/src/main/java/org/apache/nemo/client/ClientUtils.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.nemo.client;
+
+import org.apache.commons.io.FileUtils;
+import org.apache.nemo.common.Util;
+import org.apache.nemo.common.exception.MetricException;
+import org.apache.nemo.runtime.common.comm.ControlMessage;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+
+/**
+ * Utility class for the Client.
+ */
+public final class ClientUtils {
+ private static final Logger LOG = LoggerFactory.getLogger(ClientUtils.class.getName());
+
+ /**
+ * Private constructor.
+ */
+ private ClientUtils() {
+ }
+
+ /**
+ * Handler for the launch optimization message.
+ *
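+ * @param message the message received from the driver.
+ * @param driverRPCServer the RPC server through which the optimization result is sent back to the driver.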
+ */
+ static void handleOptimizationType(final ControlMessage.DriverToClientMessage message,
+ final DriverRPCServer driverRPCServer) {
+ switch (message.getOptimizationType()) {
+ case XGBoost:
+ new Thread(() ->
+ driverRPCServer.send(ControlMessage.ClientToDriverMessage.newBuilder()
+ .setType(ControlMessage.ClientToDriverMessageType.Notification)
+ .setMessage(ControlMessage.NotificationMessage.newBuilder()
+ .setOptimizationType(ControlMessage.OptimizationType.XGBoost)
+ .setData(ClientUtils.launchXGBoostScript(message.getDataCollected().getData()))
+ .build())
+ .build()))
+ .start();
+ break;
+ default:
+ break;
+ }
+ }
+
+ /**
+ * Launches the XGBoost script and waits for it to finish.
+ *
+ * @param irDagSummary the IR DAG to run the script for.
+ * @return the results file converted into string.
+ */
+ private static String launchXGBoostScript(final String irDagSummary) {
+ try {
+ final String projectRootPath = Util.fetchProjectRootPath();
+ final String scriptPath = projectRootPath + "/bin/xgboost_optimization.sh";
+ // It trains the model with the metric data of previous jobs with the same IRDAG signature.
+ final String[] command = {scriptPath, irDagSummary};
+ LOG.info("Running the python script at {}", scriptPath);
+ final ProcessBuilder builder = new ProcessBuilder(command);
+ builder.directory(new File(projectRootPath));
+ builder.redirectOutput(ProcessBuilder.Redirect.INHERIT);
+ builder.redirectError(ProcessBuilder.Redirect.INHERIT);
+ final Process process = builder.start();
+ process.waitFor();
+ LOG.info("Python script execution complete!");
+
+ final String resultsFile = projectRootPath + "/ml/results.out";
+ LOG.info("Reading the results of the script at {}", resultsFile);
+ return FileUtils.readFileToString(new File(resultsFile), "UTF-8");
+ } catch (Exception e) {
+ throw new MetricException(e);
+ }
+ }
+}
diff --git a/client/src/main/java/org/apache/nemo/client/JobLauncher.java b/client/src/main/java/org/apache/nemo/client/JobLauncher.java
index 557bef0..6725e65 100644
--- a/client/src/main/java/org/apache/nemo/client/JobLauncher.java
+++ b/client/src/main/java/org/apache/nemo/client/JobLauncher.java
@@ -21,6 +21,7 @@ package org.apache.nemo.client;
import com.google.common.annotations.VisibleForTesting;
import com.google.protobuf.ByteString;
import org.apache.commons.lang3.SerializationUtils;
+import org.apache.nemo.common.Util;
import org.apache.nemo.common.ir.IRDAG;
import org.apache.nemo.compiler.backend.nemo.NemoPlanRewriter;
import org.apache.nemo.conf.JobConf;
@@ -122,6 +123,7 @@ public final class JobLauncher {
*/
public static void setup(final String[] args) throws InjectionException, ClassNotFoundException, IOException {
// Get Job and Driver Confs
+ LOG.info("Project Root Path: {}", Util.fetchProjectRootPath());
builtJobConf = getJobConf(args);
// Registers actions for launching the DAG.
@@ -134,6 +136,8 @@ public final class JobLauncher {
.registerHandler(ControlMessage.DriverToClientMessageType.ExecutionDone, event -> jobDoneLatch.countDown())
.registerHandler(ControlMessage.DriverToClientMessageType.DataCollected, message -> COLLECTED_DATA.addAll(
SerializationUtils.deserialize(Base64.getDecoder().decode(message.getDataCollected().getData()))))
+ .registerHandler(ControlMessage.DriverToClientMessageType.LaunchOptimization, message ->
+ ClientUtils.handleOptimizationType(message, driverRPCServer))
.run();
final Configuration driverConf = getDriverConf(builtJobConf);
@@ -251,6 +255,7 @@ public final class JobLauncher {
LOG.info("Launching DAG...");
serializedDAG = Base64.getEncoder().encodeToString(SerializationUtils.serialize(dag));
jobDoneLatch = new CountDownLatch(1);
+
driverRPCServer.send(ControlMessage.ClientToDriverMessage.newBuilder()
.setType(ControlMessage.ClientToDriverMessageType.LaunchDAG)
.setLaunchDAG(ControlMessage.LaunchDAGMessage.newBuilder()
diff --git a/client/src/test/java/org/apache/nemo/client/ClientDriverRPCTest.java b/client/src/test/java/org/apache/nemo/client/ClientDriverRPCTest.java
index 24522a8..93db033 100644
--- a/client/src/test/java/org/apache/nemo/client/ClientDriverRPCTest.java
+++ b/client/src/test/java/org/apache/nemo/client/ClientDriverRPCTest.java
@@ -19,7 +19,7 @@
package org.apache.nemo.client;
import org.apache.nemo.runtime.common.comm.ControlMessage;
-import org.apache.nemo.runtime.master.ClientRPC;
+import org.apache.nemo.runtime.common.message.ClientRPC;
import org.apache.reef.tang.Injector;
import org.apache.reef.tang.Tang;
import org.apache.reef.tang.exceptions.InjectionException;
diff --git a/common/pom.xml b/common/pom.xml
index a325148..65888f8 100644
--- a/common/pom.xml
+++ b/common/pom.xml
@@ -53,6 +53,5 @@ under the License.
<artifactId>beam-sdks-java-core</artifactId>
<version>${beam.version}</version>
</dependency>
-
</dependencies>
</project>
diff --git a/common/src/main/java/org/apache/nemo/common/Util.java b/common/src/main/java/org/apache/nemo/common/Util.java
index 0778238..7a84491 100644
--- a/common/src/main/java/org/apache/nemo/common/Util.java
+++ b/common/src/main/java/org/apache/nemo/common/Util.java
@@ -43,7 +43,7 @@ import java.util.stream.Stream;
*/
public final class Util {
// Assume that this tag is never used in user application
- public static final String CONTROL_EDGE_TAG = "CONTROL_EDGE";
+ private static final String CONTROL_EDGE_TAG = "CONTROL_EDGE";
private static Instrumentation instrumentation;
@@ -59,7 +59,12 @@ public final class Util {
* @return the project root path.
*/
public static String fetchProjectRootPath() {
- return recursivelyFindLicense(Paths.get(System.getProperty("user.dir")));
+ final String nemoHome = System.getenv("NEMO_HOME");
+ if (nemoHome != null && !nemoHome.isEmpty()) {
+ return nemoHome;
+ } else {
+ return recursivelyFindLicense(Paths.get(System.getProperty("user.dir")));
+ }
}
/**
@@ -205,6 +210,26 @@ public final class Util {
}
/**
+ * Method to restore the String vertex ID from its numeric ID.
+ *
+ * @param numericId the numeric id.
+ * @return the restored string ID.
+ */
+ public static String restoreVertexId(final Integer numericId) {
+ return "vertex" + numericId;
+ }
+
+ /**
+ * Method to restore the String edge ID from its numeric ID.
+ *
+ * @param numericId the numeric id.
+ * @return the restored string ID.
+ */
+ public static String restoreEdgeId(final Integer numericId) {
+ return "edge" + numericId;
+ }
+
+ /**
* Method for the instrumentation: for getting the object size.
*
* @param args arguments.
diff --git a/common/src/main/java/org/apache/nemo/common/dag/DAG.java b/common/src/main/java/org/apache/nemo/common/dag/DAG.java
index 537d681..004e2ad 100644
--- a/common/src/main/java/org/apache/nemo/common/dag/DAG.java
+++ b/common/src/main/java/org/apache/nemo/common/dag/DAG.java
@@ -100,6 +100,14 @@ public final class DAG<V extends Vertex, E extends Edge<V>> implements DAGInterf
}
@Override
+ public E getEdgeById(final String id) {
+ return incomingEdges.values().stream().flatMap(List::stream)
+ .filter(e -> e.getId().equals(id))
+ .findFirst()
+ .orElseThrow(() -> new IllegalEdgeOperationException("There is no edge of id: " + id));
+ }
+
+ @Override
public List<V> getVertices() {
return vertices;
}
diff --git a/common/src/main/java/org/apache/nemo/common/dag/DAGInterface.java b/common/src/main/java/org/apache/nemo/common/dag/DAGInterface.java
index fe07192..c85300f 100644
--- a/common/src/main/java/org/apache/nemo/common/dag/DAGInterface.java
+++ b/common/src/main/java/org/apache/nemo/common/dag/DAGInterface.java
@@ -39,12 +39,20 @@ public interface DAGInterface<V extends Vertex, E extends Edge<V>> extends Seria
/**
* Retrieves the vertex given its ID.
*
- * @param id of the vertex to retrieve
- * @return the vertex
+ * @param id of the vertex to retrieve.
+ * @return the vertex.
*/
V getVertexById(String id);
/**
+ * Retrieves the edge given its ID.
+ *
+ * @param id of the edge to retrieve.
+ * @return the edge.
+ */
+ E getEdgeById(String id);
+
+ /**
* Retrieves the vertices of this DAG.
*
* @return the list of vertices.
diff --git a/common/src/main/java/org/apache/nemo/common/exception/IllegalEdgeOperationException.java b/common/src/main/java/org/apache/nemo/common/exception/IllegalEdgeOperationException.java
index 2c9585a..d9c79be 100644
--- a/common/src/main/java/org/apache/nemo/common/exception/IllegalEdgeOperationException.java
+++ b/common/src/main/java/org/apache/nemo/common/exception/IllegalEdgeOperationException.java
@@ -33,4 +33,13 @@ public final class IllegalEdgeOperationException extends RuntimeException {
public IllegalEdgeOperationException(final Throwable cause) {
super(cause);
}
+
+ /**
+ * IllegalEdgeOperationException.
+ *
+ * @param message message.
+ */
+ public IllegalEdgeOperationException(final String message) {
+ super(message);
+ }
}
diff --git a/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java b/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java
index 29af0a2..dfc1103 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java
@@ -20,8 +20,7 @@ package org.apache.nemo.common.ir;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.common.collect.Sets;
-import org.apache.nemo.common.KeyExtractor;
-import org.apache.nemo.common.Pair;
+import org.apache.nemo.common.PairKeyExtractor;
import org.apache.nemo.common.Util;
import org.apache.nemo.common.coder.BytesDecoderFactory;
import org.apache.nemo.common.coder.BytesEncoderFactory;
@@ -111,8 +110,16 @@ public final class IRDAG implements DAGInterface<IRVertex, IREdge> {
return canAdvance;
}
+ /**
+ * @return an IR DAG summary string, consisting of only the vertices generated from the frontend.
+ */
public String irDAGSummary() {
- return "RV" + getRootVertices().size() + "_V" + getVertices().size() + "_E" + getVertices().stream()
+ return "rv" + getRootVertices().size()
+ + "_v" + getVertices().stream()
+ .filter(v -> !v.isUtilityVertex()) // Exclude utility vertices
+ .count()
+ + "_e" + getVertices().stream()
+ .filter(v -> !v.isUtilityVertex()) // Exclude utility vertices
.mapToInt(v -> getIncomingEdgesOf(v).size())
.sum();
}
@@ -200,15 +207,12 @@ public final class IRDAG implements DAGInterface<IRVertex, IREdge> {
modifiedDAG.getOutgoingEdgesOf(vertexToDelete).stream()
.filter(e -> !Util.isControlEdge(e))
.map(IREdge::getDst)
- .forEach(dstVertex -> {
+ .forEach(dstVertex ->
modifiedDAG.getIncomingEdgesOf(vertexToDelete).stream()
.filter(e -> !Util.isControlEdge(e))
.map(IREdge::getSrc)
- .forEach(srcVertex -> {
- builder.connectVertices(
- Util.cloneEdge(streamVertexToOriginalEdge.get(vertexToDelete), srcVertex, dstVertex));
- });
- });
+ .forEach(srcVertex -> builder.connectVertices(
+ Util.cloneEdge(streamVertexToOriginalEdge.get(vertexToDelete), srcVertex, dstVertex))));
modifiedDAG = builder.buildWithoutSourceSinkCheck();
} else if (vertexToDelete instanceof MessageAggregatorVertex || vertexToDelete instanceof MessageBarrierVertex) {
modifiedDAG = rebuildExcluding(modifiedDAG, vertexGroupToDelete).buildWithoutSourceSinkCheck();
@@ -584,16 +588,9 @@ public final class IRDAG implements DAGInterface<IRVertex, IREdge> {
newEdge.setProperty(DataStoreProperty.of(DataStoreProperty.Value.LocalFileStore));
newEdge.setProperty(DataPersistenceProperty.of(DataPersistenceProperty.Value.Keep));
newEdge.setProperty(DataFlowProperty.of(DataFlowProperty.Value.Push));
- final KeyExtractor pairKeyExtractor = (element) -> {
- if (element instanceof Pair) {
- return ((Pair) element).left();
- } else {
- throw new IllegalStateException(element.toString());
- }
- };
newEdge.setPropertyPermanently(encoder);
newEdge.setPropertyPermanently(decoder);
- newEdge.setPropertyPermanently(KeyExtractorProperty.of(pairKeyExtractor));
+ newEdge.setPropertyPermanently(KeyExtractorProperty.of(new PairKeyExtractor()));
// TODO #345: Simplify insert(MessageBarrierVertex)
// these are obviously wrong, but hacks for now...
@@ -659,6 +656,11 @@ public final class IRDAG implements DAGInterface<IRVertex, IREdge> {
}
@Override
+ public IREdge getEdgeById(final String id) {
+ return modifiedDAG.getEdgeById(id);
+ }
+
+ @Override
public List<IRVertex> getVertices() {
return modifiedDAG.getVertices();
}
diff --git a/common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/PartitionerProperty.java b/common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/PartitionerProperty.java
index 4da595a..69f27e4 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/PartitionerProperty.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/PartitionerProperty.java
@@ -70,6 +70,17 @@ public final class PartitionerProperty
}
/**
+ * Static constructor.
+ * This is used via reflection by the MetricUtils class.
+ *
+ * @param value the Pair value.
+ * @return the new execution property.
+ */
+ public static PartitionerProperty of(final Pair<Type, Integer> value) {
+ return new PartitionerProperty(value);
+ }
+
+ /**
* Partitioning types.
*/
public enum Type {
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/IRVertex.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/IRVertex.java
index 4a3c6df..b8f8f24 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/vertex/IRVertex.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/IRVertex.java
@@ -86,6 +86,10 @@ public abstract class IRVertex extends Vertex implements Cloneable<IRVertex> {
return this;
}
+ /**
+ * Tells whether this vertex is a utility vertex, i.e. whether its class is declared under the
+ * {@code org.apache.nemo.common.ir.vertex.utility} package (or one of its sub-packages).
+ *
+ * @return true if this vertex is a utility vertex, false otherwise.
+ */
+ public final Boolean isUtilityVertex() {
+ // Match on the fully-qualified class name: Package#getName() carries no trailing dot, so comparing the
+ // package name against the dot-terminated prefix never matches classes declared directly in the package.
+ return this.getClass().getName().startsWith("org.apache.nemo.common.ir.vertex.utility.");
+ }
+
/**
* Get the executionProperty of the IRVertex.
*
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/IgnoreSchedulingTempDataReceiverProperty.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/IgnoreSchedulingTempDataReceiverProperty.java
index 37ab317..e11e98d 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/IgnoreSchedulingTempDataReceiverProperty.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/IgnoreSchedulingTempDataReceiverProperty.java
@@ -32,13 +32,13 @@ import org.apache.nemo.common.ir.executionproperty.VertexExecutionProperty;
public final class IgnoreSchedulingTempDataReceiverProperty extends VertexExecutionProperty<Boolean> {
private static final IgnoreSchedulingTempDataReceiverProperty IGNORE_SCHEDULING_TEMP_DATA_RECEIVER_PROPERTY =
- new IgnoreSchedulingTempDataReceiverProperty();
+ new IgnoreSchedulingTempDataReceiverProperty(true);
/**
* Constructor.
*/
- private IgnoreSchedulingTempDataReceiverProperty() {
- super(true);
+ private IgnoreSchedulingTempDataReceiverProperty(final Boolean value) {
+ super(value);
}
/**
@@ -49,4 +49,15 @@ public final class IgnoreSchedulingTempDataReceiverProperty extends VertexExecut
public static IgnoreSchedulingTempDataReceiverProperty of() {
return IGNORE_SCHEDULING_TEMP_DATA_RECEIVER_PROPERTY;
}
+
+ /**
+ * Static method exposing the constructor.
+ * This is used via reflection by the MetricUtils class.
+ *
+ * @param value the boolean value. This is always true by default for this property.
+ * @return the new execution property.
+ */
+ public static IgnoreSchedulingTempDataReceiverProperty of(final Boolean value) {
+ return new IgnoreSchedulingTempDataReceiverProperty(value);
+ }
}
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/ResourceLocalityProperty.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/ResourceLocalityProperty.java
index 63332b6..bc6aaa7 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/ResourceLocalityProperty.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/ResourceLocalityProperty.java
@@ -33,7 +33,7 @@ public final class ResourceLocalityProperty extends VertexExecutionProperty<Bool
*
* @param value value of the ExecutionProperty
*/
- private ResourceLocalityProperty(final boolean value) {
+ private ResourceLocalityProperty(final Boolean value) {
super(value);
}
@@ -43,7 +43,7 @@ public final class ResourceLocalityProperty extends VertexExecutionProperty<Bool
* @param value value of the new execution property
* @return the execution property
*/
- public static ResourceLocalityProperty of(final boolean value) {
+ public static ResourceLocalityProperty of(final Boolean value) {
return value ? SOURCE_TRUE : SOURCE_FALSE;
}
}
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/ResourceSlotProperty.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/ResourceSlotProperty.java
index 7475e21..a15acb3 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/ResourceSlotProperty.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/ResourceSlotProperty.java
@@ -33,7 +33,7 @@ public final class ResourceSlotProperty extends VertexExecutionProperty<Boolean>
*
* @param value value of the ExecutionProperty
*/
- private ResourceSlotProperty(final boolean value) {
+ private ResourceSlotProperty(final Boolean value) {
super(value);
}
@@ -43,7 +43,7 @@ public final class ResourceSlotProperty extends VertexExecutionProperty<Boolean>
* @param value value of the new execution property
* @return the execution property
*/
- public static ResourceSlotProperty of(final boolean value) {
+ public static ResourceSlotProperty of(final Boolean value) {
return value ? COMPLIANCE_TRUE : COMPLIANCE_FALSE;
}
}
diff --git a/common/src/test/java/org/apache/nemo/common/util/UtilTest.java b/common/src/test/java/org/apache/nemo/common/UtilTest.java
similarity index 76%
rename from common/src/test/java/org/apache/nemo/common/util/UtilTest.java
rename to common/src/test/java/org/apache/nemo/common/UtilTest.java
index 4e3d02d..01a51a2 100644
--- a/common/src/test/java/org/apache/nemo/common/util/UtilTest.java
+++ b/common/src/test/java/org/apache/nemo/common/UtilTest.java
@@ -16,16 +16,23 @@
* specific language governing permissions and limitations
* under the License.
*/
-package org.apache.nemo.common.util;
+package org.apache.nemo.common;
-import org.apache.nemo.common.Util;
+import org.junit.Assert;
import org.junit.Test;
+import java.nio.file.Paths;
import java.util.function.IntPredicate;
-import static org.junit.Assert.assertEquals;
+import static junit.framework.TestCase.assertEquals;
public class UtilTest {
+ @Test
+ public void testRootPath() {
+ final String one = Util.recursivelyFindLicense(Paths.get(System.getProperty("user.dir")));
+ final String two = Util.recursivelyFindLicense(Paths.get(System.getProperty("user.dir")).getParent());
+ Assert.assertEquals(one, two);
+ }
@Test
public void testCheckEqualityOfIntPredicates() {
diff --git a/compiler/frontend/beam/pom.xml b/compiler/frontend/beam/pom.xml
index 0c19975..ee63db7 100644
--- a/compiler/frontend/beam/pom.xml
+++ b/compiler/frontend/beam/pom.xml
@@ -61,7 +61,5 @@ under the License.
<scope>provided</scope>
</dependency>
<!-- https://mvnrepository.com/artifact/com.amazonaws/aws-java-sdk -->
-
-
</dependencies>
</project>
diff --git a/compiler/optimizer/pom.xml b/compiler/optimizer/pom.xml
index fb09507..bda16bc 100644
--- a/compiler/optimizer/pom.xml
+++ b/compiler/optimizer/pom.xml
@@ -48,6 +48,11 @@ under the License.
<version>${project.version}</version>
</dependency>
<dependency>
+ <groupId>org.apache.nemo</groupId>
+ <artifactId>nemo-runtime-common</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
<version>${jackson.version}</version>
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/NemoOptimizer.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/NemoOptimizer.java
index 1b5ab48..8c0da96 100644
--- a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/NemoOptimizer.java
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/NemoOptimizer.java
@@ -31,7 +31,10 @@ import org.apache.nemo.common.ir.vertex.executionproperty.IgnoreSchedulingTempDa
import org.apache.nemo.common.ir.vertex.executionproperty.ParallelismProperty;
import org.apache.nemo.compiler.optimizer.pass.runtime.Message;
import org.apache.nemo.compiler.optimizer.policy.Policy;
+import org.apache.nemo.compiler.optimizer.policy.XGBoostPolicy;
import org.apache.nemo.conf.JobConf;
+import org.apache.nemo.runtime.common.comm.ControlMessage;
+import org.apache.nemo.runtime.common.message.ClientRPC;
import org.apache.reef.tang.annotations.Parameter;
import javax.inject.Inject;
@@ -46,19 +49,27 @@ import java.util.stream.Collectors;
public final class NemoOptimizer implements Optimizer {
private final String dagDirectory;
private final Policy optimizationPolicy;
+ private final String environmentTypeStr;
+ private final ClientRPC clientRPC;
private final Map<UUID, Integer> cacheIdToParallelism = new HashMap<>();
private int irDagCount = 0;
/**
- * @param dagDirectory to store JSON representation of intermediate DAGs.
- * @param policyName the name of the optimization policy.
+ * @param dagDirectory to store JSON representation of intermediate DAGs.
+ * @param policyName the name of the optimization policy.
+ * @param environmentTypeStr the environment type of the workload to optimize the DAG for.
+ * @param clientRPC the RPC channel to communicate with the client.
*/
@Inject
private NemoOptimizer(@Parameter(JobConf.DAGDirectory.class) final String dagDirectory,
- @Parameter(JobConf.OptimizationPolicy.class) final String policyName) {
+ @Parameter(JobConf.OptimizationPolicy.class) final String policyName,
+ @Parameter(JobConf.EnvironmentType.class) final String environmentTypeStr,
+ final ClientRPC clientRPC) {
this.dagDirectory = dagDirectory;
+ this.environmentTypeStr = OptimizerUtils.filterEnvironmentTypeString(environmentTypeStr);
+ this.clientRPC = clientRPC;
try {
optimizationPolicy = (Policy) Class.forName(policyName).newInstance();
@@ -86,7 +97,7 @@ public final class NemoOptimizer implements Optimizer {
}
// Conduct compile-time optimization.
-
+ beforeCompileTimeOptimization(dag, optimizationPolicy);
optimizedDAG = optimizationPolicy.runCompileTimeOptimization(cacheFilteredDag, dagDirectory);
optimizedDAG
.storeJSON(dagDirectory, irDagId + optimizationPolicy.getClass().getSimpleName(),
@@ -114,6 +125,26 @@ public final class NemoOptimizer implements Optimizer {
}
/**
+ * Operations to be done prior to the Compile-Time Optimizations.
+ * TODO #371: This part can be reduced by not using the client RPC and sending the python script to the driver
+ * itself later on.
+ *
+ * @param dag the DAG to process.
+ * @param policy the optimization policy to optimize the DAG with.
+ */
+ private void beforeCompileTimeOptimization(final IRDAG dag, final Policy policy) {
+ if (policy instanceof XGBoostPolicy) {
+ clientRPC.send(ControlMessage.DriverToClientMessage.newBuilder()
+ .setType(ControlMessage.DriverToClientMessageType.LaunchOptimization)
+ .setOptimizationType(ControlMessage.OptimizationType.XGBoost)
+ .setDataCollected(ControlMessage.DataCollectMessage.newBuilder()
+ .setData(dag.irDAGSummary() + this.environmentTypeStr)
+ .build())
+ .build());
+ }
+ }
+
+ /**
* Handle data caching.
* At first, it search the edges having cache ID from the given dag and update them to the given map.
* Then, if some edge of a submitted dag is annotated as "cached" and the data was produced already,
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/OptimizerUtils.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/OptimizerUtils.java
new file mode 100644
index 0000000..1f55096
--- /dev/null
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/OptimizerUtils.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.nemo.compiler.optimizer;
+
+import org.apache.nemo.common.Pair;
+import org.apache.nemo.common.Util;
+import org.apache.nemo.common.exception.InvalidParameterException;
+import org.apache.nemo.common.exception.UnsupportedMethodException;
+
+/**
+ * Utility class for optimizer.
+ */
+public final class OptimizerUtils {
+  // Layout of a formatted metric string: [0] vertex/edge type, [1-4] component ID, [5-8] EP key index.
+  private static final int FORMATTED_LENGTH = 9;
+  private static final int ID_BEGIN = 1;
+  private static final int EP_KEY_BEGIN = 5;
+  // Values of the leading type digit.
+  private static final int VERTEX_TYPE = 1;
+  private static final int EDGE_TYPE = 2;
+
+  /**
+   * Private constructor.
+   */
+  private OptimizerUtils() {
+  }
+
+  /**
+   * Restore the formatted string into a pair of vertex/edge id and the execution property.
+   * The string has to be exactly 9 characters long: character [0] tells whether the component is a vertex (1)
+   * or an edge (2), characters [1-4] are the numeric ID of the component, and characters [5-8] are the
+   * execution property key index.
+   *
+   * @param string the formatted string.
+   * @return a pair of vertex/edge id and the execution property key index.
+   * @throws InvalidParameterException  if the string is not 9 characters long.
+   * @throws UnsupportedMethodException if the leading digit is neither 1 (vertex) nor 2 (edge).
+   * @throws NumberFormatException      if any of the three parts is not a number.
+   */
+  public static Pair<String, Integer> stringToIdAndEPKeyIndex(final String string) {
+    if (string.length() != FORMATTED_LENGTH) {
+      throw new InvalidParameterException("The metric data should follow the format of "
+        + "[0]: index indicating vertex/edge, [1-4]: id of the component, and [5-8]: EP Key index. Current: " + string);
+    }
+    final int idx = Integer.parseInt(string.substring(0, ID_BEGIN));
+    final int numericId = Integer.parseInt(string.substring(ID_BEGIN, EP_KEY_BEGIN));
+    final String id;
+    if (idx == VERTEX_TYPE) {
+      id = Util.restoreVertexId(numericId);
+    } else if (idx == EDGE_TYPE) {
+      id = Util.restoreEdgeId(numericId);
+    } else {
+      throw new UnsupportedMethodException("The index " + idx + " cannot be categorized into a vertex or an edge");
+    }
+    return Pair.of(id, Integer.parseInt(string.substring(EP_KEY_BEGIN, FORMATTED_LENGTH)));
+  }
+
+  /**
+   * Normalizes a free-form, keyword-containing string into one of the known environment type names.
+   * The match is case-insensitive and the first matching rule wins, in this order: "transient",
+   * "large" together with "shuffle", "disaggregat", "stream", "small", "skew".
+   *
+   * @param environmentType the input string.
+   * @return the formatted string corresponding to each type, or an empty string if the input is null or
+   * contains none of the keywords.
+   */
+  public static String filterEnvironmentTypeString(final String environmentType) {
+    if (environmentType == null) {
+      return ""; // Default
+    }
+    // Lowercase once, and locale-independently: with the default locale, e.g. a Turkish one maps 'I' to a
+    // dotless 'i', and "TRANSIENT" would no longer contain "transient".
+    final String lowered = environmentType.toLowerCase(java.util.Locale.ROOT);
+    if (lowered.contains("transient")) {
+      return "transient";
+    } else if (lowered.contains("large") && lowered.contains("shuffle")) {
+      return "large_shuffle";
+    } else if (lowered.contains("disaggregat")) {
+      return "disaggregation";
+    } else if (lowered.contains("stream")) {
+      return "streaming";
+    } else if (lowered.contains("small")) {
+      return "small_size";
+    } else if (lowered.contains("skew")) {
+      return "data_skew";
+    } else {
+      return ""; // Default
+    }
+  }
+}
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/XGBoostPass.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/XGBoostPass.java
new file mode 100644
index 0000000..0f91c98
--- /dev/null
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/XGBoostPass.java
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.nemo.compiler.optimizer.pass.compiletime.annotating;
+
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.apache.nemo.common.Pair;
+import org.apache.nemo.common.exception.*;
+import org.apache.nemo.common.ir.IRDAG;
+import org.apache.nemo.common.ir.edge.IREdge;
+import org.apache.nemo.common.ir.executionproperty.EdgeExecutionProperty;
+import org.apache.nemo.common.ir.executionproperty.ExecutionProperty;
+import org.apache.nemo.common.ir.executionproperty.VertexExecutionProperty;
+import org.apache.nemo.common.ir.vertex.IRVertex;
+import org.apache.nemo.compiler.optimizer.OptimizerUtils;
+import org.apache.nemo.runtime.common.metric.MetricUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.Serializable;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
+
+/**
+ * Pass for applying XGBoost optimizations.
+ * <p>
+ * 1. The pass first triggers the client to run the XGBoost script, located under the `ml` python package.
+ * 2. The client runs the script, which trains the tree model using the metrics collected before, and constructs
+ * a tree model, which then predicts the 'knobs' that minimizes the JCT based on the weights of the leaves (JCT).
+ * 3. It receives the results, and in which direction each of the knobs should be optimized, and reconstructs the
+ * execution properties in the form that they are tuned.
+ * 4. The newly reconstructed execution properties are injected and the workload runs after the optimization.
+ */
+@Annotates()
+public final class XGBoostPass extends AnnotatingPass {
+  private static final Logger LOG = LoggerFactory.getLogger(XGBoostPass.class.getName());
+
+  private static final BlockingQueue<String> MESSAGE_QUEUE = new LinkedBlockingQueue<>();
+  // A configured ObjectMapper can be shared, so it is built once instead of on every apply().
+  private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
+  // Formatted into 9 digits: [0] vertex/edge, [1-4] ID, [5-8] EP key index.
+  private static final int FEATURE_LENGTH = 9;
+
+  /**
+   * Default constructor.
+   */
+  public XGBoostPass() {
+    super(XGBoostPass.class);
+  }
+
+  /**
+   * Blocks until a message is pushed through {@link #pushMessage(String)}, and applies the tunings that the
+   * message describes onto the DAG. The message is a JSON list of maps with the keys "feature", "split" and "val".
+   * Entries whose "feature" is missing or is not 9 characters long are ignored.
+   *
+   * @param dag the DAG to tune.
+   * @return the given DAG, with the applicable execution properties updated in place.
+   * @throws CompileTimeOptimizationException on any failure other than an {@link InvalidParameterException},
+   *                                          which is only logged.
+   */
+  @Override
+  public IRDAG apply(final IRDAG dag) {
+    try {
+      final String message = XGBoostPass.takeMessage();
+      LOG.info("Received message from the client: {}", message);
+
+      if (message.isEmpty()) {
+        LOG.info("No optimization included in the message. Returning the original dag.");
+        return dag;
+      }
+
+      final List<Map<String, String>> listOfMap =
+        OBJECT_MAPPER.readValue(message, new TypeReference<List<Map<String, String>>>() {
+        });
+      listOfMap.stream()
+        // A missing "feature" entry is skipped instead of failing the whole pass with a NullPointerException.
+        .filter(m -> m.get("feature") != null && m.get("feature").length() == FEATURE_LENGTH)
+        .forEach(m -> applyTuning(dag, m));
+    } catch (final InvalidParameterException e) {
+      LOG.warn(e.getMessage());
+      return dag;
+    } catch (final Exception e) {
+      throw new CompileTimeOptimizationException(e);
+    }
+
+    return dag;
+  }
+
+  /**
+   * Applies a single tuning entry onto the vertex or the edge that it designates.
+   * The new execution property is kept only if the DAG still passes its integrity check; otherwise, the
+   * execution property found before the change is put back.
+   *
+   * @param dag    the DAG to tune.
+   * @param tuning the entry, with the keys "feature", "split" and "val".
+   */
+  private static void applyTuning(final IRDAG dag, final Map<String, String> tuning) {
+    final Pair<String, Integer> idAndEPKey = OptimizerUtils.stringToIdAndEPKeyIndex(tuning.get("feature"));
+    LOG.info("Tuning: {} of {} should be {} than {}",
+      idAndEPKey.right(), idAndEPKey.left(), tuning.get("val"), tuning.get("split"));
+    final ExecutionProperty<? extends Serializable> newEP = MetricUtils.keyAndValueToEP(idAndEPKey.right(),
+      Double.valueOf(tuning.get("split")), Double.valueOf(tuning.get("val")));
+    try {
+      // NOTE(review): in both branches below, the original EP is null if the component had no such property
+      // before, in which case the rollback calls setProperty(null) - confirm that this is intended.
+      if (idAndEPKey.left().startsWith("vertex")) {
+        final IRVertex vertex = dag.getVertexById(idAndEPKey.left());
+        final VertexExecutionProperty<?> originalEP = vertex.getExecutionProperties().stream()
+          .filter(ep -> ep.getClass().isAssignableFrom(newEP.getClass())).findFirst().orElse(null);
+        vertex.setProperty((VertexExecutionProperty) newEP);
+        if (!dag.checkIntegrity().isPassed()) {
+          vertex.setProperty(originalEP);
+        }
+      } else if (idAndEPKey.left().startsWith("edge")) {
+        final IREdge edge = dag.getEdgeById(idAndEPKey.left());
+        final EdgeExecutionProperty<?> originalEP = edge.getExecutionProperties().stream()
+          .filter(ep -> ep.getClass().isAssignableFrom(newEP.getClass())).findFirst().orElse(null);
+        edge.setProperty((EdgeExecutionProperty) newEP);
+        if (!dag.checkIntegrity().isPassed()) {
+          edge.setProperty(originalEP);
+        }
+      }
+    } catch (IllegalVertexOperationException | IllegalEdgeOperationException e) {
+      // Best-effort: a tuning that cannot be applied is skipped, but no longer silently.
+      LOG.warn("Skipping the tuning of {}, as it could not be applied.", idAndEPKey.left(), e);
+    }
+  }
+
+  /**
+   * @param message push the message to the message queue.
+   */
+  public static void pushMessage(final String message) {
+    MESSAGE_QUEUE.add(message);
+  }
+
+  /**
+   * Waits until a message is available.
+   *
+   * @return the message from the blocking queue.
+   * @throws MetricException if the thread is interrupted while waiting.
+   */
+  private static String takeMessage() {
+    try {
+      return MESSAGE_QUEUE.take();
+    } catch (InterruptedException e) {
+      // Restore the interrupt flag, so that the code up the stack can still observe the interruption.
+      Thread.currentThread().interrupt();
+      throw new MetricException("Interrupted while waiting for message: " + e);
+    }
+  }
+}
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/policy/XGBoostPolicy.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/policy/XGBoostPolicy.java
new file mode 100644
index 0000000..3960d81
--- /dev/null
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/policy/XGBoostPolicy.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.nemo.compiler.optimizer.policy;
+
+import org.apache.nemo.common.ir.IRDAG;
+import org.apache.nemo.compiler.optimizer.pass.compiletime.annotating.XGBoostPass;
+import org.apache.nemo.compiler.optimizer.pass.compiletime.composite.DefaultCompositePass;
+import org.apache.nemo.compiler.optimizer.pass.runtime.Message;
+
+/**
+ * A policy that enforces values retrieved by an optimization by XGBoost.
+ * Running an application repetitively with this policy automatically improves performance.
+ * At compile time, it registers the {@link DefaultCompositePass}, followed by the {@link XGBoostPass}.
+ */
+public final class XGBoostPolicy implements Policy {
+  public static final PolicyBuilder BUILDER =
+    new PolicyBuilder()
+      .registerCompileTimePass(new DefaultCompositePass())
+      .registerCompileTimePass(new XGBoostPass());
+  // The policy built from the BUILDER, to which every call is forwarded.
+  private final Policy delegate;
+
+  /**
+   * Default constructor.
+   */
+  public XGBoostPolicy() {
+    delegate = BUILDER.build();
+  }
+
+  @Override
+  public IRDAG runCompileTimeOptimization(final IRDAG dag, final String dagDirectory) {
+    return delegate.runCompileTimeOptimization(dag, dagDirectory);
+  }
+
+  @Override
+  public IRDAG runRunTimeOptimizations(final IRDAG dag, final Message<?> message) {
+    return delegate.runRunTimeOptimizations(dag, message);
+  }
+}
diff --git a/conf/src/main/java/org/apache/nemo/conf/JobConf.java b/conf/src/main/java/org/apache/nemo/conf/JobConf.java
index 684f48c..83d8957 100644
--- a/conf/src/main/java/org/apache/nemo/conf/JobConf.java
+++ b/conf/src/main/java/org/apache/nemo/conf/JobConf.java
@@ -77,9 +77,16 @@ public final class JobConf extends ConfigurationModuleBuilder {
}
/**
+ * Specifies the type of the environment the workload runs on. (e.g., transient / large_shuffle)
+ * It is registered under the short name {@code env}, and defaults to an empty string.
+ */
+ @NamedParameter(doc = "Environment type", short_name = "env", default_value = "")
+ public final class EnvironmentType implements Name<String> {
+ }
+
+ /**
* Address pointing to the DB for saving metrics.
*/
- @NamedParameter(doc = "DB address", short_name = "db_dir", default_value =
+ @NamedParameter(doc = "DB address", short_name = "db_address", default_value =
"jdbc:postgresql://nemo-optimization.cabbufr3evny.us-west-2.rds.amazonaws.com:5432/nemo_optimization")
public final class DBAddress implements Name<String> {
}
diff --git a/ml/nemo_xgboost_optimization.py b/ml/nemo_xgboost_optimization.py
new file mode 100644
index 0000000..f198b5d
--- /dev/null
+++ b/ml/nemo_xgboost_optimization.py
@@ -0,0 +1,385 @@
+#!/usr/bin/python3
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import getopt
+import json
+import sys
+from pathlib import Path
+
+import numpy as np
+import psycopg2 as pg
+import sqlite3 as sq
+import xgboost as xgb
+from sklearn import preprocessing
+
+
+# import matplotlib.pyplot as plt
+
+# ########################################################
+# METHODS
+# ########################################################
+def format_row(duration, inputsize, jvmmemsize, totalmemsize, vertex_properties, edge_properties):
+ duration_in_sec = int(duration) // 1000
+ inputsize_in_10kb = int(inputsize) // 10240 # capable of expressing upto around 20TB with int range
+ jvmmemsize_in_mb = int(jvmmemsize) // 1048576
+ totalmemsize_in_mb = int(totalmemsize) // 1048576
+ return f'{duration_in_sec} 0:{inputsize_in_10kb} 1:{jvmmemsize_in_mb} 2:{totalmemsize_in_mb} {vertex_properties} {edge_properties}'
+
+
+
+# ########################################################
+def load_data_from_db(tablename):
+ conn = None
+
+ try:
+ host = "nemo-optimization.cabbufr3evny.us-west-2.rds.amazonaws.com"
+ dbname = "nemo_optimization"
+ dbuser = "postgres"
+ dbpwd = "fake_password"
+ conn = pg.connect(host=host, dbname=dbname, user=dbuser, password=dbpwd)
+ print("Connected to the PostgreSQL DB.")
+ except:
+ try:
+ sqlite_file = "./optimization_db.sqlite"
+ conn = sq.connect(sqlite_file)
+ print("Connected to the SQLite DB.")
+ except:
+ print("I am unable to connect to the database. Try running the script with `./bin/xgboost_optimization.sh`")
+
+ sql = "SELECT * from " + tablename
+ cur = conn.cursor()
+ try:
+ cur.execute(sql)
+ print("Loaded data from the DB.")
+ except:
+ print("I can't run " + sql)
+
+ rows = cur.fetchall()
+ processed_rows = [format_row(row[1], row[2], row[3], row[4], row[5], row[6]) for row in rows]
+ cur.close()
+ conn.close()
+ return processed_rows
+
+
+# ########################################################
+def write_to_file(filename, rows):
+ f = open(filename, 'w')
+ for row in rows:
+ f.write(row + "\n")
+ f.close()
+
+
+def encode_processed_rows(processed_rows, col_to_id):
+ for i, row in enumerate(processed_rows):
+ arr = row.split()
+ for j, it in enumerate(arr[1:]):
+ k, v = it.split(':')
+ ek = col_to_id[int(k)]
+ arr[j + 1] = f'{ek}:{v}'
+ processed_rows[i] = ' '.join(arr)
+ return processed_rows
+
+
+def decode_rows(rows, id_to_col):
+ for i, row in enumerate(rows):
+ arr = row.split()
+ for j, it in enumerate(arr[1:]):
+ ek, v = it.split(':')
+ k = id_to_col[int(ek)]
+ arr[j + 1] = f'{k}:{v}'
+ rows[i] = ' '.join(arr)
+ return rows
+
+
+# ########################################################
+def stringify_num(num):
+ return str(round(num, 2))
+
+
+def dict_union(d1, d2):
+ """Merge d2 into d1 and return d1; d1 is modified in place.
+
+ For each key/value pair of d2:
+ - key already in d1: two dicts are merged recursively, anything else is added up (d1[k] + v).
+ - key not in d1, dict value: the dict is stored as it is (not copied).
+ - key not in d1, any other value: the pair is inserted only if one of the four rules below holds,
+ each of them comparing v against the largest/smallest value currently in d1. The pair is dropped
+ otherwise.
+
+ NOTE(review): the last case calls max()/min() on d1, which raises ValueError on an empty d1 -
+ presumably callers never get there with an empty d1; verify.
+ """
+ for k, v in d2.items():
+ if k in d1:
+ if type(d1[k]) is dict and type(v) is dict: # When same 'feature'
+ d1[k] = dict_union(d1[k], v)
+ else: # When same 'split'
+ d1[k] = d1[k] + v
+ elif type(v) is dict: # When no initial data
+ d1[k] = v
+ else: # k = split, v = diff. include if it does not violate.
+ # v is positive, every value in d1 is negative, and k is below the largest key of d1: add it.
+ if v > 0 > max(d1.values()) and k < max(d1.keys()): # If no positive values yet
+ d1[k] = v
+ # v exceeds the (positive) largest value of d1: the entry holding that value is replaced by k.
+ elif v > max(d1.values()) > 0: # Update if greater value
+ max_key = max(d1, key=lambda key: d1[key])
+ del d1[max_key]
+ d1[k] = v
+ # v is negative, every value in d1 is positive, and k is above the smallest key of d1: add it.
+ elif v < 0 < min(d1.values()) and min(d1.keys()) < k: # If no negative values yet
+ d1[k] = v
+ # v is below the (negative) smallest value of d1: the entry holding that value is replaced by k.
+ elif v < min(d1.values()) < 0: # Update if smaller value
+ min_key = min(d1, key=lambda key: d1[key])
+ del d1[min_key]
+ d1[k] = v
+ return d1
+
+
+# ########################################################
+class Tree:
+ root = None
+ idx_to_node = {}
+
+ def append_to_dict_if_not_exists(self, idx, node):
+ if idx not in self.idx_to_node:
+ self.idx_to_node[idx] = node
+
+ def addNode(self, index, feature_id, split, yes, no, missing, value):
+ n = None
+ if self.root == None:
+ self.root = Node(None)
+ n = self.root
+ self.append_to_dict_if_not_exists(index, n)
+ else:
+ n = self.idx_to_node[index]
+
+ self.append_to_dict_if_not_exists(yes, Node(n))
+ self.append_to_dict_if_not_exists(no, Node(n))
+ self.append_to_dict_if_not_exists(missing, Node(n))
+ n.addAttributes(index, feature_id, split, yes, no, missing, value, self.idx_to_node)
+
+ def importanceDict(self):
+ return self.root.importanceDict()
+
+ def __str__(self):
+ return json.dumps(json.loads(str(self.root)), indent=4)
+
+
+class Node:
+ parent = None
+ index = None
+
+ feature = None
+ split = None
+ left = None
+ right = None
+ missing = None
+
+ value = None
+
+ def __init__(self, parent):
+ self.parent = parent
+
+ def addAttributes(self, index, feature_id, split, yes, no, missing, value, idx_to_node):
+ self.index = index
+ if feature_id == 'Leaf':
+ self.value = value
+ else:
+ self.feature = feature_id
+ self.split = split
+ self.left = idx_to_node[yes]
+ self.right = idx_to_node[no]
+ self.missing = idx_to_node[missing]
+
+ def isLeaf(self):
+ return self.value != None
+
+ def isRoot(self):
+ return self.parent == None
+
+ def getIndex(self):
+ return self.index
+
+ def getLeft(self):
+ return self.left
+
+ def getRight(self):
+ return self.right
+
+ def getMissing(self):
+ return self.missing
+
+ def getApprox(self):
+ if self.isLeaf():
+ return self.value
+ else:
+ lapprox = self.left.getApprox()
+ rapprox = self.right.getApprox()
+ if rapprox != 0 and abs(lapprox / rapprox) < 0.04: # smaller than 4% then ignore
+ return rapprox
+ elif lapprox != 0 and abs(rapprox / lapprox) < 0.04:
+ return lapprox
+ else:
+ return (lapprox + rapprox) / 2
+
+ def getDiff(self):
+ lapprox = self.left.getApprox()
+ rapprox = self.right.getApprox()
+ if (rapprox != 0 and abs(lapprox / rapprox) < 0.04) or (lapprox != 0 and abs(rapprox / lapprox) < 0.04):
+ return 0 # ignore
+ return lapprox - rapprox
+
+ def importanceDict(self):
+ if self.isLeaf():
+ return {}
+ else:
+ d = {}
+ d[self.feature] = {self.split: self.getDiff()}
+ return dict_union(d, dict_union(self.left.importanceDict(), self.right.importanceDict()))
+
+ def __str__(self):
+ if self.isLeaf():
+ return f'{stringify_num(self.value)}'
+ else:
+ left = str(self.left) if self.left.isLeaf() else json.loads(str(self.left))
+ right = str(self.right) if self.right.isLeaf() else json.loads(str(self.right))
+ return json.dumps({self.index: f'{self.feature}' + '{' + stringify_num(self.getApprox()) + ',' + stringify_num(
+ self.getDiff()) + '}', 'L' + self.left.getIndex(): left, 'R' + self.right.getIndex(): right})
+
+
+# ########################################################
+# MAIN FUNCTION
+# ########################################################
+try:
+ opts, args = getopt.getopt(sys.argv[1:], "ht:m:i:", ["tablename=", "memsize=", "inputsize="])
+except getopt.GetoptError:
+ print('nemo_xgboost_optimization.py -t <tablename>')
+ sys.exit(2)
+for opt, arg in opts:
+ if opt == '-h':
+ print('nemo_xgboost_optimization.py -t <tablename>')
+ sys.exit()
+ elif opt in ("-t", "--tablename"):
+ tablename = arg
+ elif opt in ("-m", "--memsize"):
+ memsize = arg
+ elif opt in ("-i", "--inputsize"):
+ inputsize = arg
+
+# File in which the best booster of this table is saved, and from which it is loaded back.
+modelname = tablename + "_bst.model"
+processed_rows = load_data_from_db(tablename)
+# write_to_file('process_test', processed_rows)
+
+## Make Dictionary
+# Collect the key (the integer before the ':') of every feature of every row, skipping the first
+# token of each row.
+col = []
+for row in processed_rows:
+ arr = row.split()
+ for it in arr[1:]:
+ k, v = it.split(':')
+ col.append(int(k))
+# Encode the keys with a LabelEncoder, and keep the mapping in both directions.
+le = preprocessing.LabelEncoder()
+ids = le.fit_transform(col)
+col_to_id = dict(zip(col, ids))
+id_to_col = dict(zip(ids, col))
+
+## PREPROCESSING DATA FOR TRAINING
+# encode_processed_rows() rewrites processed_rows in place and returns that same list.
+encoded_rows = encode_processed_rows(processed_rows, col_to_id)
+write_to_file('nemo_optimization.out', encoded_rows)
+# write_to_file('decode_test', decode_rows(encoded_rows, id_to_col))
+ddata = xgb.DMatrix('nemo_optimization.out')
+
+# Mean label of the first 20 rows; a prediction further than `allowance` from its label is
+# counted as an error below.
+avg_20_duration = np.mean(ddata.get_label()[:20])
+print("average job duration: ", avg_20_duration)
+allowance = avg_20_duration // 25 # 4%
+
+row_size = len(processed_rows)
+print("total_rows: ", row_size)
+
+## TRAIN THE MODEL (REGRESSION)
+# Every sixth row (index % 6 == 5) is held out for testing; all the others are used for training.
+dtrain = ddata.slice([i for i in range(0, row_size) if i % 6 != 5]) # mod is not 5
+print("train_rows: ", dtrain.num_row())
+dtest = ddata.slice([i for i in range(0, row_size) if i % 6 == 5]) # mod is 5
+print("test_rows: ", dtest.num_row())
+labels = dtest.get_label()
+
+## Load existing booster, if it exists
+bst_opt = xgb.Booster(model_file=modelname) if Path(modelname).is_file() else None
+preds_opt = bst_opt.predict(dtest) if bst_opt is not None else None
+error_opt = (sum(1 for i in range(len(preds_opt)) if abs(preds_opt[i] - labels[i]) > allowance) / float(
+ len(preds_opt))) if preds_opt is not None else 1
+print('opt_error=%f' % error_opt)
+min_error = error_opt
+
+learning_rates = [0.1, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9]
+for lr in learning_rates:
+ param = {'max_depth': 6, 'eta': lr, 'verbosity': 0, 'objective': 'reg:linear'}
+
+ watchlist = [(dtest, 'eval'), (dtrain, 'train')]
+ num_round = row_size // 10
+ bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=5)
+
+ preds = bst.predict(dtest)
+ error = (sum(1 for i in range(len(preds)) if abs(preds[i] - labels[i]) > allowance) / float(len(preds))) if len(
+ preds) > 0 else 1.0
+ print('error=%f' % error)
+
+ ## Better booster
+ if error <= error_opt:
+ bst_opt = bst
+ bst.save_model(modelname)
+ min_error = error
+
+print('minimum error=%f' % min_error)
+
+## Let's now use bst_opt
+## Check out the histogram by uncommenting the lines below
+# fscore = bst_opt.get_fscore()
+# sorted_fscore = sorted(fscore.items(), key=lambda kv: kv[1])
+# for i in range(len(sorted_fscore)):
+# print("\nSplit Value Histogram:")
+# feature = sorted_fscore.pop()[0]
+# print(feature, "=", id_to_col[int(feature[1:])])
+# hg = bst_opt.get_split_value_histogram(feature)
+# print(hg)
+
+df = bst_opt.trees_to_dataframe()
+# print("Trees to dataframe")
+# print(df)
+
+trees = {}
+for index, row in df.iterrows():
+ if row['Tree'] not in trees: # Tree number = index
+ trees[row['Tree']] = Tree()
+
+ translated_feature = id_to_col[int(row['Feature'][1:])] if row['Feature'].startswith('f') else row['Feature']
+ # print(translated_feature)
+ trees[row['Tree']].addNode(row['ID'], translated_feature, row['Split'], row['Yes'], row['No'], row['Missing'],
+ row['Gain'])
+
+results = {}
+print("\nGenerated Trees:")
+for t in trees.values():
+ results = dict_union(results, t.importanceDict())
+ # print(t)
+
+print("\nImportanceDict")
+print(json.dumps(results, indent=2))
+
+print("\nSummary")
+resultsJson = []
+for k, v in results.items():
+ for kk, vv in v.items():
+ resultsJson.append({'feature': k, 'split': kk, 'val': vv})
+ how = 'greater' if vv > 0 else 'smaller'
+ restring = f'{k} should be {how} than {kk}'
+ print(restring)
+
+with open("results.out", "w") as file:
+ file.write(json.dumps(resultsJson, indent=2))
+
+# Visualize tree
+# xgb.plot_tree(bst_opt)
+# plt.show()
diff --git a/ml/requirements.txt b/ml/requirements.txt
new file mode 100644
index 0000000..1729945
--- /dev/null
+++ b/ml/requirements.txt
@@ -0,0 +1,9 @@
+graphviz==0.10.1
+matplotlib==3.0.3
+numpy==1.16.2
+psycopg2==2.7.7
+psycopg2-binary==2.7.7
+pygraphviz==1.5
+sklearn==0.0
+xgboost==0.82
+pandas==0.24.1
diff --git a/pom.xml b/pom.xml
index 7c1615a..e46e029 100644
--- a/pom.xml
+++ b/pom.xml
@@ -328,6 +328,10 @@ under the License.
<exclude>**/*.json</exclude>
<exclude>**/.editorconfig</exclude>
<exclude>**/config.gypi</exclude>
+ <!-- ML -->
+ <exclude>**/*.txt</exclude>
+ <exclude>**/*.out</exclude>
+ <exclude>venv/**</exclude>
<!-- EditorConfig -->
<exclude>.editorconfig</exclude>
<!-- formatter.xml -->
diff --git a/runtime/common/pom.xml b/runtime/common/pom.xml
index 6f8c68e..00a458f 100644
--- a/runtime/common/pom.xml
+++ b/runtime/common/pom.xml
@@ -79,5 +79,12 @@ under the License.
<artifactId>jackson-databind</artifactId>
<version>${jackson.version}</version>
</dependency>
+ <!--For serialization & deserialization of Beam components-->
+ <dependency>
+ <groupId>org.apache.nemo</groupId>
+ <artifactId>nemo-compiler-frontend-beam</artifactId>
+ <version>${project.version}</version>
+ <scope>test</scope>
+ </dependency>
</dependencies>
</project>
diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/ClientRPC.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/message/ClientRPC.java
similarity index 99%
rename from runtime/master/src/main/java/org/apache/nemo/runtime/master/ClientRPC.java
rename to runtime/common/src/main/java/org/apache/nemo/runtime/common/message/ClientRPC.java
index 3405ca9..4317bb1 100644
--- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/ClientRPC.java
+++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/message/ClientRPC.java
@@ -16,7 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
-package org.apache.nemo.runtime.master;
+package org.apache.nemo.runtime.common.message;
import com.google.protobuf.InvalidProtocolBufferException;
import org.apache.nemo.conf.JobConf;
diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/metric/MetricUtils.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/metric/MetricUtils.java
index 0f22f26..8d458ae 100644
--- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/metric/MetricUtils.java
+++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/metric/MetricUtils.java
@@ -22,6 +22,7 @@ package org.apache.nemo.runtime.common.metric;
import com.google.common.collect.HashBiMap;
import org.apache.commons.lang3.SerializationUtils;
import org.apache.nemo.common.Pair;
+import org.apache.nemo.common.Util;
import org.apache.nemo.common.coder.DecoderFactory;
import org.apache.nemo.common.coder.EncoderFactory;
import org.apache.nemo.common.exception.MetricException;
@@ -31,16 +32,18 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.ByteArrayInputStream;
-import java.io.IOException;
-import java.nio.file.Files;
-import java.nio.file.Path;
-import java.nio.file.Paths;
+import java.io.Serializable;
+import java.lang.reflect.Method;
import java.sql.*;
+import java.util.Arrays;
+import java.util.Optional;
import java.util.concurrent.CountDownLatch;
-import java.util.stream.Stream;
+import java.util.function.Supplier;
+import java.util.stream.IntStream;
/**
* Utility class for metrics.
+ * TODO #372: This class should later be refactored into a separate metric package.
*/
public final class MetricUtils {
private static final Logger LOG = LoggerFactory.getLogger(MetricUtils.class.getName());
@@ -49,20 +52,24 @@ public final class MetricUtils {
private static final CountDownLatch MUST_UPDATE_EP_KEY_METADATA = new CountDownLatch(1);
private static final CountDownLatch MUST_UPDATE_EP_METADATA = new CountDownLatch(1);
- private static final Pair<HashBiMap<Integer, Class<? extends ExecutionProperty>>,
- HashBiMap<Pair<Integer, Integer>, ExecutionProperty<?>>> METADATA = loadMetaData();
- // BiMap of (1) INDEX and (2) the Execution Property class
- private static final HashBiMap<Integer, Class<? extends ExecutionProperty>>
- EP_KEY_METADATA = METADATA.left();
- // BiMap of (1) the Execution Property class INDEX and the value INDEX pair and (2) the Execution Property.
- private static final HashBiMap<Pair<Integer, Integer>, ExecutionProperty<?>>
- EP_METADATA = METADATA.right();
+ // BiMap of (1) INDEX and (2) the Execution Property class and the value type class.
+ static final HashBiMap<Integer, Pair<Class<? extends ExecutionProperty>, Class<? extends Serializable>>>
+ EP_KEY_METADATA = HashBiMap.create();
+ // BiMap of (1) the Execution Property class INDEX and the value INDEX pair and (2) the Execution Property value.
+ private static final HashBiMap<Pair<Integer, Integer>, ExecutionProperty<? extends Serializable>>
+ EP_METADATA = HashBiMap.create();
- private static final int VERTEX = 1;
- private static final int EDGE = 2;
+ // Runs once, when the class is initialized: checks that the PostgreSQL JDBC driver class can be loaded,
+ // failing with a MetricException if it cannot, and then calls loadMetaData().
+ static {
+ try {
+ Class.forName("org.postgresql.Driver");
+ } catch (ClassNotFoundException e) {
+ throw new MetricException("PostgreSQL Driver not found: " + e);
+ }
+ loadMetaData();
+ }
public static final String SQLITE_DB_NAME =
- "jdbc:sqlite:" + MetricUtils.fetchProjectRootPath() + "/optimization_db.sqlite3";
+ "jdbc:sqlite:" + Util.fetchProjectRootPath() + "/optimization_db.sqlite3";
public static final String POSTGRESQL_METADATA_DB_NAME =
"jdbc:postgresql://nemo-optimization.cabbufr3evny.us-west-2.rds.amazonaws.com:5432/nemo_optimization";
private static final String METADATA_TABLE_NAME = "nemo_optimization_meta";
@@ -78,8 +85,7 @@ public final class MetricUtils {
*
* @return the loaded BiMaps, or initialized ones.
*/
- private static Pair<HashBiMap<Integer, Class<? extends ExecutionProperty>>,
- HashBiMap<Pair<Integer, Integer>, ExecutionProperty<?>>> loadMetaData() {
+ private static void loadMetaData() {
try (Connection c = DriverManager.getConnection(MetricUtils.POSTGRESQL_METADATA_DB_NAME,
"postgres", "fake_password")) {
try (Statement statement = c.createStatement()) {
@@ -87,43 +93,33 @@ public final class MetricUtils {
statement.executeUpdate(
"CREATE TABLE IF NOT EXISTS " + METADATA_TABLE_NAME
- + " (key TEXT NOT NULL UNIQUE, data BYTEA NOT NULL);");
+ + " (type TEXT NOT NULL, key INT NOT NULL UNIQUE, value BYTEA NOT NULL, "
+ + "created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP);");
final ResultSet rsl = statement.executeQuery(
- "SELECT * FROM " + METADATA_TABLE_NAME + " WHERE key='EP_KEY_METADATA';");
- LOG.info("Metadata can be loaded.");
- if (rsl.next()) {
- final HashBiMap<Integer, Class<? extends ExecutionProperty>> indexEpKeyBiMap =
- SerializationUtils.deserialize(rsl.getBytes("Data"));
- rsl.close();
-
- final ResultSet rsr = statement.executeQuery(
- "SELECT * FROM " + METADATA_TABLE_NAME + " WHERE key='EP_METADATA';");
- if (rsr.next()) {
- final HashBiMap<Pair<Integer, Integer>, ExecutionProperty<?>> indexEpBiMap =
- SerializationUtils.deserialize(rsr.getBytes("Data"));
- rsr.close();
-
- METADATA_LOADED.countDown();
- LOG.info("Metadata successfully loaded from DB.");
- return Pair.of(indexEpKeyBiMap, indexEpBiMap);
- } else {
- METADATA_LOADED.countDown();
- LOG.info("No initial metadata for EP.");
- return Pair.of(indexEpKeyBiMap, HashBiMap.create());
- }
- } else {
- METADATA_LOADED.countDown();
- LOG.info("No initial metadata.");
- return Pair.of(HashBiMap.create(), HashBiMap.create());
+ "SELECT * FROM " + METADATA_TABLE_NAME + " WHERE type='EP_KEY_METADATA';");
+ LOG.info("Metadata can be successfully loaded.");
+ while (rsl.next()) {
+ EP_KEY_METADATA.put(rsl.getInt("key"),
+ SerializationUtils.deserialize(rsl.getBytes("value")));
+ }
+ rsl.close();
+
+ final ResultSet rsr = statement.executeQuery(
+ "SELECT * FROM " + METADATA_TABLE_NAME + " WHERE type='EP_METADATA';");
+ while (rsr.next()) {
+ final Integer l = rsr.getInt("key"); // stored as (EP key index * 10000 + value index).
+ EP_METADATA.put(Pair.of(l / 10000, l % 10000), // decode with the variable l, not the literal 1.
+ SerializationUtils.deserialize(rsr.getBytes("value")));
}
+ rsr.close();
+ METADATA_LOADED.countDown();
+ LOG.info("Metadata successfully loaded from DB.");
} catch (Exception e) {
LOG.warn("Loading metadata from DB failed: ", e);
- return Pair.of(HashBiMap.create(), HashBiMap.create());
}
} catch (Exception e) {
LOG.warn("Loading metadata from DB failed : ", e);
- return Pair.of(HashBiMap.create(), HashBiMap.create());
}
}
@@ -150,25 +146,25 @@ public final class MetricUtils {
statement.setQueryTimeout(30); // set timeout to 30 sec.
if (MUST_UPDATE_EP_KEY_METADATA.getCount() == 0) {
- try (PreparedStatement pstmt = c.prepareStatement(
- "INSERT INTO " + METADATA_TABLE_NAME + " (key, data) "
- + "VALUES ('EP_KEY_METADATA', ?) ON CONFLICT (key) DO UPDATE SET data = excluded.data;")) {
- pstmt.setBinaryStream(1,
- new ByteArrayInputStream(SerializationUtils.serialize(EP_KEY_METADATA)));
- pstmt.executeUpdate();
- LOG.info("EP Key Metadata saved to DB.");
- }
+ EP_KEY_METADATA.forEach((l, r) -> {
+ try {
+ insertOrUpdateMetadata(c, "EP_KEY_METADATA", l, r);
+ } catch (SQLException e) {
+ LOG.warn("Saving of Metadata to DB failed: ", e);
+ }
+ });
+ LOG.info("EP Key Metadata saved to DB.");
}
if (MUST_UPDATE_EP_METADATA.getCount() == 0) {
- try (PreparedStatement pstmt =
- c.prepareStatement("INSERT INTO " + METADATA_TABLE_NAME + "(key, data) "
- + "VALUES ('EP_METADATA', ?) ON CONFLICT (key) DO UPDATE SET data = excluded.data;")) {
- pstmt.setBinaryStream(1,
- new ByteArrayInputStream(SerializationUtils.serialize(EP_METADATA)));
- pstmt.executeUpdate();
- LOG.info("EP Metadata saved to DB.");
- }
+ EP_METADATA.forEach((l, r) -> {
+ try {
+ insertOrUpdateMetadata(c, "EP_METADATA", l.left() * 10000 + l.right(), r);
+ } catch (SQLException e) {
+ LOG.warn("Saving of Metadata to DB failed: ", e);
+ }
+ });
+ LOG.info("EP Metadata saved to DB.");
}
}
} catch (SQLException e) {
@@ -177,23 +173,42 @@ public final class MetricUtils {
}
/**
+ * Utility method to save key, value to the metadata table.
+ *
+ * @param c the connection to the DB.
+ * @param type the metadata type to write to the DB metadata table (e.g., 'EP_KEY_METADATA' or 'EP_METADATA').
+ * @param key the key to write to the DB metadata table (integer).
+ * @param value the value to write to the DB metadata table (object).
+ * @throws SQLException SQLException on the way.
+ */
+ private static void insertOrUpdateMetadata(final Connection c, final String type,
+ final Integer key, final Serializable value) throws SQLException {
+ try (PreparedStatement pstmt = c.prepareStatement(
+ "INSERT INTO " + METADATA_TABLE_NAME + " (type, key, value) "
+ + "VALUES ('" + type + "', " + key + ", ?) ON CONFLICT (key) DO UPDATE SET value = excluded.value;")) {
+ pstmt.setBinaryStream(1, new ByteArrayInputStream(SerializationUtils.serialize(value)));
+ pstmt.executeUpdate();
+ }
+ }
+
+ /**
* Stringify execution properties of an IR DAG.
*
* @param irdag IR DAG to observe.
* @return the pair of stringified execution properties. Left is for vertices, right is for edges.
*/
- static Pair<String, String> stringifyIRDAGProperties(final IRDAG irdag) {
+ public static Pair<String, String> stringifyIRDAGProperties(final IRDAG irdag) {
final StringBuilder vStringBuilder = new StringBuilder();
final StringBuilder eStringBuilder = new StringBuilder();
irdag.getVertices().forEach(v ->
v.getExecutionProperties().forEachProperties(ep ->
- epFormatter(vStringBuilder, VERTEX, v.getNumericId(), ep)));
+ epFormatter(vStringBuilder, 1, v.getNumericId(), ep)));
irdag.getVertices().forEach(v ->
irdag.getIncomingEdgesOf(v).forEach(e ->
e.getExecutionProperties().forEachProperties(ep ->
- epFormatter(eStringBuilder, EDGE, e.getNumericId(), ep))));
+ epFormatter(eStringBuilder, 2, e.getNumericId(), ep))));
// Update the metric metadata if new execution property key / values have been discovered and updates are required.
updateMetaData();
@@ -210,17 +225,13 @@ public final class MetricUtils {
*/
private static void epFormatter(final StringBuilder builder, final int idx,
final Integer numericId, final ExecutionProperty<?> ep) {
+ // Formatted into 9 digits: digit 0: vertex/edge, digits 1-4: ID, digits 5-8: EP key index.
builder.append(idx);
- builder.append(numericId);
- builder.append("0");
- final Integer epKeyIndex = EP_KEY_METADATA.inverse().computeIfAbsent(ep.getClass(), epClass -> {
- // Update the metadata if new EP key has been discovered.
- LOG.info("New EP Key Index: {} for {}", EP_KEY_METADATA.size() + 1, epClass.getSimpleName());
- MUST_UPDATE_EP_KEY_METADATA.countDown();
- return EP_KEY_METADATA.size() + 1;
- });
- builder.append(epKeyIndex);
+ builder.append(String.format("%04d", numericId));
+ final Integer epKeyIndex = getEpKeyIndex(ep);
+ builder.append(String.format("%04d", epKeyIndex));
+ // Format value to an index.
builder.append(":");
final Integer epIndex = valueToIndex(epKeyIndex, ep);
builder.append(epIndex);
@@ -228,6 +239,74 @@ public final class MetricUtils {
}
/**
+ * Get the EP Key index from the metadata.
+ *
+ * @param ep the EP to retrieve the Key index of.
+ * @return the Key index.
+ */
+ static Integer getEpKeyIndex(final ExecutionProperty<?> ep) {
+ return EP_KEY_METADATA.inverse()
+ .computeIfAbsent(Pair.of(ep.getClass(), getParameterType(ep.getClass(), ep.getValue().getClass())),
+ epClassPair -> {
+ final Integer idx = EP_KEY_METADATA.keySet().stream().mapToInt(i -> i).max().orElse(0) + 1;
+ LOG.info("New EP Key Index: {} for {}", idx, epClassPair.left().getSimpleName());
+ // Update the metadata if new EP key has been discovered.
+ MUST_UPDATE_EP_KEY_METADATA.countDown();
+ return idx;
+ });
+ }
+
+ /**
+ * Recursive method for getting the parameter type of the execution property.
+ * This can be used, for example, to get DecoderFactory, instead of BeamDecoderFactory.
+ *
+ * @param epClass execution property class to observe.
+ * @param valueClass the value class of the execution property.
+ * @return the parameter type.
+ */
+ private static Class<? extends Serializable> getParameterType(final Class<? extends ExecutionProperty> epClass,
+ final Class<? extends Serializable> valueClass) {
+ if (!getMethodFor(epClass, "of", valueClass.getSuperclass()).isPresent()
+ || !(Serializable.class.isAssignableFrom(valueClass.getSuperclass()))) {
+ final Class<? extends Serializable> candidate = Arrays.stream(valueClass.getInterfaces())
+ .filter(vc -> Serializable.class.isAssignableFrom(vc) && getMethodFor(epClass, "of", vc).isPresent())
+ .map(vc -> getParameterType(epClass, ((Class<? extends Serializable>) vc))).findFirst().orElse(null);
+ return candidate == null ? valueClass : candidate;
+ } else {
+ return getParameterType(epClass, ((Class<? extends Serializable>) valueClass.getSuperclass()));
+ }
+ }
+
+ /**
+ * Utility method for getting an optional method called 'name' for the class.
+ *
+ * @param clazz class to get the method of.
+ * @param name the name of the method.
+ * @param valueTypes the value types of the method.
+ * @return optional of the method. It returns Optional.empty() if the method could not be found.
+ */
+ public static Optional<Method> getMethodFor(final Class<? extends ExecutionProperty> clazz,
+ final String name, final Class<?>... valueTypes) {
+ try {
+ final Method mthd = clazz.getMethod(name, valueTypes);
+ return Optional.of(mthd);
+ } catch (final Exception e) {
+ return Optional.empty();
+ }
+ }
+
+ /**
+ * Inverse method of the #getEpKeyIndex method.
+ *
+ * @param index the index of the EP Key.
+ * @return the class of the execution property (EP), as well as the type of the value of the EP.
+ */
+ private static Pair<Class<? extends ExecutionProperty>, Class<? extends Serializable>> getEpPairFromKeyIndex(
+ final Integer index) {
+ return EP_KEY_METADATA.get(index);
+ }
+
+ /**
* Helper method to convert Execution Property value objects to an integer index.
* It updates the metadata for the metrics if new EP values are discovered.
*
@@ -235,7 +314,7 @@ public final class MetricUtils {
* @param ep the execution property containing the value.
* @return the converted value index.
*/
- private static Integer valueToIndex(final Integer epKeyIndex, final ExecutionProperty<?> ep) {
+ static Integer valueToIndex(final Integer epKeyIndex, final ExecutionProperty<?> ep) {
final Object o = ep.getValue();
if (o instanceof Enum) {
@@ -245,10 +324,10 @@ public final class MetricUtils {
} else if (o instanceof Boolean) {
return ((Boolean) o) ? 1 : 0;
} else {
- final ExecutionProperty<?> ep1;
+ final ExecutionProperty<? extends Serializable> ep1;
if (o instanceof EncoderFactory || o instanceof DecoderFactory) {
ep1 = EP_METADATA.values().stream()
- .filter(ep2 -> ep2.getValue().toString().equals(o.toString()))
+ .filter(ep2 -> ep2.getValue().toString().equals(o.toString()) || ep2.getValue().equals(o))
.findFirst().orElse(null);
} else {
ep1 = EP_METADATA.values().stream()
@@ -259,9 +338,9 @@ public final class MetricUtils {
if (ep1 != null) {
return EP_METADATA.inverse().get(ep1).right();
} else {
- final Integer valueIndex = Math.toIntExact(EP_METADATA.keySet().stream()
+ final Integer valueIndex = EP_METADATA.keySet().stream()
.filter(pair -> pair.left().equals(epKeyIndex))
- .count()) + 1;
+ .mapToInt(Pair::right).max().orElse(0) + 1;
// Update the metadata if new EP value has been discovered.
EP_METADATA.put(Pair.of(epKeyIndex, valueIndex), ep);
LOG.info("New EP Index: ({}, {}) for {}", epKeyIndex, valueIndex, ep);
@@ -272,29 +351,99 @@ public final class MetricUtils {
}
/**
- * Finds the project root path.
+ * Helper method to do the opposite of the #valueToIndex method.
+ * It receives the split and the direction of the tweak value (which shows the target index value),
+ * and returns the actual value which the execution property uses.
*
+ * @param split the split value, from which to start from.
+ * @param tweak the tweak value, to which we should tweak the split value.
+ * @param epKeyIndex the EP Key index to retrieve information from.
* @return the project root path.
*/
- private static String fetchProjectRootPath() {
- return recursivelyFindLicense(Paths.get(System.getProperty("user.dir")));
+ static Serializable indexToValue(final Double split, final Double tweak, final Integer epKeyIndex) {
+ final Class<? extends Serializable> targetObjectClass = getEpPairFromKeyIndex(epKeyIndex).right();
+ final boolean splitIsInteger = split.compareTo((double) split.intValue()) == 0;
+ final Pair<Integer, Integer> splitIntVal = splitIsInteger // (left, right) integer neighbours of the split.
+ ? Pair.of(split.intValue() - 1, split.intValue() + 1)
+ : Pair.of(split.intValue(), split.intValue() + 1);
+
+ if (targetObjectClass.isEnum()) {
+ final int ordinal;
+ if (split < 0) {
+ ordinal = 0;
+ } else {
+ final int maxOrdinal = targetObjectClass.getEnumConstants().length - 1; // constants only, not all public fields.
+ final int left = splitIntVal.left() <= 0 ? 0 : splitIntVal.left();
+ final int right = splitIntVal.right() >= maxOrdinal ? maxOrdinal : splitIntVal.right();
+ ordinal = tweak < 0 ? left : right;
+ }
+ LOG.info("Translated: {} into ENUM with ordinal {}", split, ordinal);
+ return targetObjectClass.getEnumConstants()[ordinal];
+ } else if (targetObjectClass.isAssignableFrom(Integer.class)) {
+ final Double val = split + tweak + 0.5;
+ final Integer res = val.intValue();
+ LOG.info("Translated: {} into INTEGER of {}", split, res);
+ return res;
+ } else if (targetObjectClass.isAssignableFrom(Boolean.class)) {
+ final Boolean res;
+ if (split < 0) {
+ res = false;
+ } else if (split > 1) {
+ res = true;
+ } else {
+ final Boolean left = splitIntVal.left() >= 1; // false by default, true if >= 1
+ final Boolean right = splitIntVal.right() > 0; // true by default, false if <= 0
+ res = tweak < 0 ? left : right;
+ }
+ LOG.info("Translated: {} into BOOLEAN of {}", split, res);
+ return res;
+ } else {
+ final Supplier<IntStream> valueCandidates = () -> EP_METADATA.keySet().stream()
+ .filter(p -> p.left().equals(epKeyIndex))
+ .mapToInt(Pair::right);
+ final Integer left = valueCandidates.get()
+ .filter(n -> n < split)
+ .map(n -> -n).sorted().map(n -> -n) // maximum among smaller values
+ .findFirst().orElse(valueCandidates.get().min().getAsInt());
+ final Integer right = valueCandidates.get()
+ .filter(n -> n > split)
+ .sorted() // minimum among larger values
+ .findFirst().orElse(valueCandidates.get().max().getAsInt());
+ final Integer targetValue = tweak < 0 ? left : right;
+ final Serializable res = EP_METADATA.get(Pair.of(epKeyIndex, targetValue)).getValue();
+ LOG.info("Translated: {} into VALUE of {}", split, res);
+ return res;
+ }
}
/**
- * Helper method to recursively find the LICENSE file.
+ * Receives the pair of execution property and value classes, and returns the optimized value of the EP.
*
- * @param path the path to search for.
- * @return the path containing the LICENSE file.
+ * @param epKeyIndex the EP Key index to retrieve the new EP from.
+ * @param split the split point.
+ * @param tweak the direction in which to tweak the execution property value.
+ * @return The execution property constructed from the key index and the split value.
*/
- private static String recursivelyFindLicense(final Path path) {
- try (Stream stream = Files.find(path, 1, (p, attributes) -> p.endsWith("LICENSE"))) {
- if (stream.count() > 0) {
- return path.toAbsolutePath().toString();
- } else {
- return recursivelyFindLicense(path.getParent());
- }
- } catch (IOException e) {
+ public static ExecutionProperty<? extends Serializable> keyAndValueToEP(
+ final Integer epKeyIndex,
+ final Double split,
+ final Double tweak) {
+
+ final Serializable value = indexToValue(split, tweak, epKeyIndex);
+ final Class<? extends ExecutionProperty> epClass = getEpPairFromKeyIndex(epKeyIndex).left();
+
+ final ExecutionProperty<? extends Serializable> ep;
+ try {
+ final Method staticConstructor = getMethodFor(epClass, "of", getParameterType(epClass, value.getClass()))
+ .orElseThrow(NoSuchMethodException::new);
+ ep = (ExecutionProperty<? extends Serializable>) staticConstructor.invoke(null, value);
+ } catch (final NoSuchMethodException e) {
+ throw new MetricException("Class " + epClass.getName()
+ + " does not have a static method exposing the constructor 'of' with value type " + value.getClass().getName()
+ + ": " + e);
+ } catch (final Exception e) {
throw new MetricException(e);
}
+ return ep;
}
}
diff --git a/runtime/common/src/main/proto/ControlMessage.proto b/runtime/common/src/main/proto/ControlMessage.proto
index b7f257e..97e30fb 100644
--- a/runtime/common/src/main/proto/ControlMessage.proto
+++ b/runtime/common/src/main/proto/ControlMessage.proto
@@ -23,14 +23,17 @@ package protobuf;
option java_package = "org.apache.nemo.runtime.common.comm";
option java_outer_classname = "ControlMessage";
+// Messages from client to driver
enum ClientToDriverMessageType {
LaunchDAG = 0;
DriverShutdown = 1;
+ Notification = 2;
}
message ClientToDriverMessage {
required ClientToDriverMessageType type = 1;
optional LaunchDAGMessage launchDAG = 2;
+ optional NotificationMessage message = 3;
}
message LaunchDAGMessage {
@@ -38,20 +41,32 @@ message LaunchDAGMessage {
optional bytes broadcastVars = 2;
}
-message DataCollectMessage {
- required string data = 1;
+enum OptimizationType {
+ XGBoost = 0;
}
+message NotificationMessage {
+ required OptimizationType optimizationType = 1;
+ required string data = 2;
+}
+
+// Messages from driver to client
enum DriverToClientMessageType {
DriverStarted = 0;
DriverReady = 1;
DataCollected = 2;
ExecutionDone = 3;
+ LaunchOptimization = 4;
}
message DriverToClientMessage {
required DriverToClientMessageType type = 1;
optional DataCollectMessage dataCollected = 2;
+ optional OptimizationType optimizationType = 3;
+}
+
+message DataCollectMessage {
+ required string data = 1;
}
enum MessageType {
@@ -158,7 +173,6 @@ message MetricMsg {
}
// Messages between Executors
-
enum ByteTransferDataDirection {
INITIATOR_SENDS_DATA = 0;
INITIATOR_RECEIVES_DATA = 1;
diff --git a/runtime/common/src/test/java/org/apache/nemo/runtime/common/metric/MetricUtilsTest.java b/runtime/common/src/test/java/org/apache/nemo/runtime/common/metric/MetricUtilsTest.java
new file mode 100644
index 0000000..d1da5f0
--- /dev/null
+++ b/runtime/common/src/test/java/org/apache/nemo/runtime/common/metric/MetricUtilsTest.java
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.nemo.runtime.common.metric;
+
+import org.apache.nemo.common.coder.DecoderFactory;
+import org.apache.nemo.common.coder.EncoderFactory;
+import org.apache.nemo.common.ir.edge.executionproperty.DataFlowProperty;
+import org.apache.nemo.common.ir.edge.executionproperty.DecoderProperty;
+import org.apache.nemo.common.ir.edge.executionproperty.EncoderProperty;
+import org.apache.nemo.common.ir.executionproperty.ExecutionProperty;
+import org.apache.nemo.common.ir.vertex.executionproperty.ParallelismProperty;
+import org.apache.nemo.common.ir.vertex.executionproperty.ResourceSlotProperty;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.Serializable;
+
+public class MetricUtilsTest {
+
+ @Test
+ public void testEnumIndexAndValue() {
+ final DataFlowProperty.Value pull = DataFlowProperty.Value.Pull;
+ final DataFlowProperty.Value push = DataFlowProperty.Value.Push;
+
+ final DataFlowProperty ep = DataFlowProperty.of(pull);
+ final Integer epKeyIndex = MetricUtils.getEpKeyIndex(ep);
+ final Integer idx = MetricUtils.valueToIndex(epKeyIndex, ep);
+ // Pull is of ordinal index 0
+ Assert.assertEquals(Integer.valueOf(0), idx);
+
+ final Object pull1 = MetricUtils.indexToValue(0.5, -0.1, epKeyIndex);
+ Assert.assertEquals(pull, pull1);
+ final Object push1 = MetricUtils.indexToValue(0.5, 0.1, epKeyIndex);
+ Assert.assertEquals(push, push1);
+ final Object pull2 = MetricUtils.indexToValue(-0.5, -0.1, epKeyIndex);
+ Assert.assertEquals(pull, pull2);
+ final Object pull3 = MetricUtils.indexToValue(-0.5, 0.1, epKeyIndex);
+ Assert.assertEquals(pull, pull3);
+ final Object push2 = MetricUtils.indexToValue(2.0, 1.0, epKeyIndex);
+ Assert.assertEquals(push, push2);
+ final Object push3 = MetricUtils.indexToValue(1.1, -0.1, epKeyIndex);
+ Assert.assertEquals(push, push3);
+ final Object push4 = MetricUtils.indexToValue(1.1, 0.1, epKeyIndex);
+ Assert.assertEquals(push, push4);
+ }
+
+ @Test
+ public void testIntegerBooleanIndexAndValue() {
+ final Integer one = 1;
+ final Integer hundred = 100;
+
+ final ParallelismProperty pEp1 = ParallelismProperty.of(one);
+ final ParallelismProperty pEp100 = ParallelismProperty.of(hundred);
+ final Integer pEp1KeyIndex = MetricUtils.getEpKeyIndex(pEp1);
+ final Integer pEp100KeyIndex = MetricUtils.getEpKeyIndex(pEp100);
+ Assert.assertEquals(Integer.valueOf(1), MetricUtils.valueToIndex(pEp1KeyIndex, pEp1));
+ Assert.assertEquals(Integer.valueOf(100), MetricUtils.valueToIndex(pEp100KeyIndex, pEp100));
+
+
+ final ResourceSlotProperty rsEpT = ResourceSlotProperty.of(true);
+ final ResourceSlotProperty rsEpF = ResourceSlotProperty.of(false);
+ final Integer rsEpTKeyIndex = MetricUtils.getEpKeyIndex(rsEpT);
+ final Integer rsEpFKeyIndex = MetricUtils.getEpKeyIndex(rsEpF);
+ Assert.assertEquals(Integer.valueOf(1), MetricUtils.valueToIndex(rsEpTKeyIndex, rsEpT));
+ Assert.assertEquals(Integer.valueOf(0), MetricUtils.valueToIndex(rsEpFKeyIndex, rsEpF));
+
+ final Object one1 = MetricUtils.indexToValue(1.5, -0.1, pEp1KeyIndex);
+ final Object one2 = MetricUtils.indexToValue(0.5, 0.1, pEp1KeyIndex);
+ final Object one3 = MetricUtils.indexToValue(2.0, -0.6, pEp1KeyIndex);
+ final Object one4 = MetricUtils.indexToValue(0.0, 0.5, pEp1KeyIndex);
+ Assert.assertEquals(one, one1);
+ Assert.assertEquals(one, one2);
+ Assert.assertEquals(one, one3);
+ Assert.assertEquals(one, one4);
+
+ final Object hundred1 = MetricUtils.indexToValue(100.5, -0.1, pEp100KeyIndex);
+ final Object hundred2 = MetricUtils.indexToValue(99.5, 0.1, pEp100KeyIndex);
+ Assert.assertEquals(hundred, hundred1);
+ Assert.assertEquals(hundred, hundred2);
+
+ final Object t1 = MetricUtils.indexToValue(1.5, -0.1, rsEpTKeyIndex);
+ final Object t2 = MetricUtils.indexToValue(0.1, 0.5, rsEpTKeyIndex);
+ final Object t3 = MetricUtils.indexToValue(1.5, 0.1, rsEpTKeyIndex);
+ Assert.assertEquals(true, t1);
+ Assert.assertEquals(true, t2);
+ Assert.assertEquals(true, t3);
+
+ final Object f1 = MetricUtils.indexToValue(0.5, -0.1, rsEpFKeyIndex);
+ final Object f2 = MetricUtils.indexToValue(-0.5, 0.1, rsEpFKeyIndex);
+ final Object f3 = MetricUtils.indexToValue(-0.5, -0.1, rsEpFKeyIndex);
+ Assert.assertEquals(false, f1);
+ Assert.assertEquals(false, f2);
+ Assert.assertEquals(false, f3);
+ }
+
+ @Test
+ public void testOtherIndexAndValue() {
+ final EncoderFactory ef = new EncoderFactory.DummyEncoderFactory();
+ final DecoderFactory df = new DecoderFactory.DummyDecoderFactory();
+
+ final EncoderProperty eEp = EncoderProperty.of(ef);
+ final DecoderProperty dEp = DecoderProperty.of(df);
+ final Integer eEpKeyIndex = MetricUtils.getEpKeyIndex(eEp);
+ final Integer dEpKeyIndex = MetricUtils.getEpKeyIndex(dEp);
+ final Integer efidx = MetricUtils.valueToIndex(eEpKeyIndex, eEp);
+ final Integer dfidx = MetricUtils.valueToIndex(dEpKeyIndex, dEp);
+
+ final Object ef1 = MetricUtils.indexToValue(0.1 + efidx, -0.1, eEpKeyIndex);
+ final Object ef2 = MetricUtils.indexToValue(-0.1 + efidx, 0.1, eEpKeyIndex);
+ Assert.assertEquals("EP_INDEX: (" + eEpKeyIndex + ", " + efidx + ")", ef.toString(), ef1.toString());
+ Assert.assertEquals("EP_INDEX: (" + eEpKeyIndex + ", " + efidx + ")", ef.toString(), ef2.toString());
+
+ final Object df1 = MetricUtils.indexToValue(0.1 + dfidx, -0.1, dEpKeyIndex);
+ final Object df2 = MetricUtils.indexToValue(-0.1 + dfidx, 0.1, dEpKeyIndex);
+ Assert.assertEquals("EP_INDEX: (" + dEpKeyIndex + ", " + dfidx + ")", df.toString(), df1.toString());
+ Assert.assertEquals("EP_INDEX: (" + dEpKeyIndex + ", " + dfidx + ")", df.toString(), df2.toString());
+ }
+
+ @Test
+ public void testPairAndValueToEP() {
+ final DataFlowProperty.Value pull = DataFlowProperty.Value.Pull;
+ final DataFlowProperty ep = DataFlowProperty.of(pull);
+ final Integer epKeyIndex = MetricUtils.getEpKeyIndex(ep);
+ final Integer idx = MetricUtils.valueToIndex(epKeyIndex, ep);
+ Assert.assertEquals(Integer.valueOf(0), idx);
+
+ final ExecutionProperty<? extends Serializable> ep2 =
+ MetricUtils.keyAndValueToEP(epKeyIndex, 0.5, -0.1);
+ Assert.assertEquals(ep, ep2);
+ }
+
+ @Test
+ public void validateStaticConstructorsOfExecutionProperties() {
+ MetricUtils.EP_KEY_METADATA.values().forEach(p -> Assert.assertTrue(
+ p.left().getName() + " should have an 'of' method with its value class, " + p.right().getName(),
+ MetricUtils.getMethodFor(p.left(), "of", p.right()).isPresent()));
+ }
+}
diff --git a/runtime/driver/src/main/java/org/apache/nemo/driver/NemoDriver.java b/runtime/driver/src/main/java/org/apache/nemo/driver/NemoDriver.java
index 3a981db..afca55a 100644
--- a/runtime/driver/src/main/java/org/apache/nemo/driver/NemoDriver.java
+++ b/runtime/driver/src/main/java/org/apache/nemo/driver/NemoDriver.java
@@ -22,12 +22,13 @@ import org.apache.commons.lang3.SerializationUtils;
import org.apache.commons.lang3.concurrent.BasicThreadFactory;
import org.apache.nemo.common.ir.IdManager;
import org.apache.nemo.compiler.optimizer.pass.compiletime.annotating.ResourceSitePass;
+import org.apache.nemo.compiler.optimizer.pass.compiletime.annotating.XGBoostPass;
import org.apache.nemo.conf.JobConf;
import org.apache.nemo.runtime.common.RuntimeIdManager;
import org.apache.nemo.runtime.common.comm.ControlMessage;
+import org.apache.nemo.runtime.common.message.ClientRPC;
import org.apache.nemo.runtime.common.message.MessageParameters;
import org.apache.nemo.runtime.master.BroadcastManagerMaster;
-import org.apache.nemo.runtime.master.ClientRPC;
import org.apache.nemo.runtime.master.RuntimeMaster;
import org.apache.reef.annotations.audience.DriverSide;
import org.apache.reef.driver.client.JobMessageObserver;
@@ -111,6 +112,7 @@ public final class NemoDriver {
this.clientRPC = clientRPC;
// TODO #69: Support job-wide execution property
ResourceSitePass.setBandwidthSpecificationString(bandwidthString);
+ clientRPC.registerHandler(ControlMessage.ClientToDriverMessageType.Notification, this::handleNotification);
clientRPC.registerHandler(ControlMessage.ClientToDriverMessageType.LaunchDAG, message -> {
startSchedulingUserDAG(message.getLaunchDAG().getDag());
final Map<Serializable, Object> broadcastVars =
@@ -195,6 +197,21 @@ public final class NemoDriver {
}
/**
+ * handler for notifications from the client.
+ *
+ * @param message message from the client.
+ */
+ private void handleNotification(final ControlMessage.ClientToDriverMessage message) {
+ switch (message.getMessage().getOptimizationType()) {
+ case XGBoost:
+ XGBoostPass.pushMessage(message.getMessage().getData());
+ break;
+ default:
+ break;
+ }
+ }
+
+ /**
* Evaluator failed.
*/
public final class FailedEvaluatorHandler implements EventHandler<FailedEvaluator> {
diff --git a/runtime/driver/src/main/java/org/apache/nemo/driver/UserApplicationRunner.java b/runtime/driver/src/main/java/org/apache/nemo/driver/UserApplicationRunner.java
index 20635e0..bf759a4 100644
--- a/runtime/driver/src/main/java/org/apache/nemo/driver/UserApplicationRunner.java
+++ b/runtime/driver/src/main/java/org/apache/nemo/driver/UserApplicationRunner.java
@@ -50,6 +50,15 @@ public final class UserApplicationRunner {
private final Backend<PhysicalPlan> backend;
private final PlanRewriter planRewriter;
+ /**
+ * Constructor.
+ *
+ * @param maxScheduleAttempt maximum scheduling attempt.
+ * @param optimizer the nemo optimizer.
+ * @param backend the backend to actually execute the job.
+ * @param runtimeMaster the runtime master.
+ * @param planRewriter plan rewriter
+ */
@Inject
private UserApplicationRunner(@Parameter(JobConf.MaxTaskAttempt.class) final int maxScheduleAttempt,
final Optimizer optimizer,
diff --git a/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/datatransfer/DataTransferTest.java b/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/datatransfer/DataTransferTest.java
index e5e01dd..c20c9a5 100644
--- a/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/datatransfer/DataTransferTest.java
+++ b/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/datatransfer/DataTransferTest.java
@@ -35,6 +35,7 @@ import org.apache.nemo.common.ir.vertex.executionproperty.ScheduleGroupProperty;
import org.apache.nemo.common.test.EmptyComponents;
import org.apache.nemo.conf.JobConf;
import org.apache.nemo.runtime.common.RuntimeIdManager;
+import org.apache.nemo.runtime.common.message.ClientRPC;
import org.apache.nemo.runtime.common.message.MessageEnvironment;
import org.apache.nemo.runtime.common.message.MessageParameters;
import org.apache.nemo.runtime.common.message.PersistentConnectionToMasterMap;
@@ -50,7 +51,6 @@ import org.apache.nemo.runtime.executor.data.BlockManagerWorker;
import org.apache.nemo.runtime.executor.data.DataUtil;
import org.apache.nemo.runtime.executor.data.SerializerManager;
import org.apache.nemo.runtime.master.BlockManagerMaster;
-import org.apache.nemo.runtime.master.ClientRPC;
import org.apache.nemo.runtime.master.RuntimeMaster;
import org.apache.nemo.runtime.master.metric.MetricManagerMaster;
import org.apache.nemo.runtime.master.metric.MetricMessageHandler;
diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java
index c8d165b..ea6c31c 100644
--- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java
+++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java
@@ -29,6 +29,7 @@ import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.nemo.conf.JobConf;
import org.apache.nemo.runtime.common.RuntimeIdManager;
import org.apache.nemo.runtime.common.comm.ControlMessage;
+import org.apache.nemo.runtime.common.message.ClientRPC;
import org.apache.nemo.runtime.common.message.MessageContext;
import org.apache.nemo.runtime.common.message.MessageEnvironment;
import org.apache.nemo.runtime.common.message.MessageListener;
@@ -113,13 +114,29 @@ public final class RuntimeMaster {
private final Server metricServer;
private final MetricStore metricStore;
+ /**
+ * Constructor.
+ *
+ * @param scheduler the scheduler implementation.
+ * @param containerManager the container manager, in charge of the available containers.
+ * @param metricMessageHandler the handler for metric messages.
+ * @param masterMessageEnvironment message environment for the runtime master.
+ * @param metricManagerMaster metric manager master.
+ * @param clientRPC the RPC channel to communicate with the client.
+ * @param planStateManager the manager that keeps track of the plan state.
+ * @param jobId the Job ID, provided by the user.
+ * @param dbAddress the DB Address, provided by the user.
+ * @param dbId the ID for the given DB.
+ * @param dbPassword the password for the given DB.
+ * @param dagDirectory directory of the DAG to save the json files and metrics into.
+ */
@Inject
private RuntimeMaster(final Scheduler scheduler,
final ContainerManager containerManager,
final MetricMessageHandler metricMessageHandler,
final MessageEnvironment masterMessageEnvironment,
- final ClientRPC clientRPC,
final MetricManagerMaster metricManagerMaster,
+ final ClientRPC clientRPC,
final PlanStateManager planStateManager,
@Parameter(JobConf.JobId.class) final String jobId,
@Parameter(JobConf.DBAddress.class) final String dbAddress,
@@ -189,6 +206,12 @@ public final class RuntimeMaster {
return server;
}
+ /**
+ * Record IR DAG related metrics.
+ *
+ * @param irdag the IR DAG to record.
+ * @param planId the ID of the IR DAG Physical Plan.
+ */
public void recordIRDAGMetrics(final IRDAG irdag, final String planId) {
metricStore.getOrCreateMetric(JobMetric.class, planId).setIRDAG(irdag);
}
@@ -202,7 +225,7 @@ public final class RuntimeMaster {
metricStore.dumpAllMetricToFile(Paths.get(dagDirectory,
"Metric_" + jobId + "_" + System.currentTimeMillis() + ".json").toString());
- metricStore.saveOptimizationMetricsToDB(dbAddress, dbId, dbPassword);
+ metricStore.saveOptimizationMetricsToDB(dbAddress, jobId, dbId, dbPassword);
}
/**
diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/metric/MetricStore.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/metric/MetricStore.java
index 4bcd8ee..fabe9ac 100644
--- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/metric/MetricStore.java
+++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/metric/MetricStore.java
@@ -78,6 +78,13 @@ public final class MetricStore {
private static final MetricStore INSTANCE = new MetricStore();
}
+ /**
+ * Get the metric class by its name.
+ *
+ * @param className the name of the class.
+ * @param <T> type of the metric.
+ * @return the class of the type of the metric.
+ */
public <T extends Metric> Class<T> getMetricClassByName(final String className) {
if (!metricList.keySet().contains(className)) {
throw new NoSuchElementException();
@@ -238,13 +245,21 @@ public final class MetricStore {
/**
* Save the job metrics for the optimization to the DB, in the form of LibSVM, to a local SQLite DB.
* The metrics are as follows: the JCT (duration), and the IR DAG execution properties.
+ *
+ * @param jobId The ID of the job which we record the metrics of.
*/
- private void saveOptimizationMetricsToLocal() {
+ private void saveOptimizationMetricsToLocal(final String jobId) {
final String[] syntax = {"INTEGER PRIMARY KEY AUTOINCREMENT"};
+ try {
+ Class.forName("org.sqlite.JDBC");
+ } catch (ClassNotFoundException e) {
+ throw new MetricException("SQLite Driver not loaded: " + e);
+ }
+
try (Connection c = DriverManager.getConnection(MetricUtils.SQLITE_DB_NAME)) {
LOG.info("Opened database successfully at {}", MetricUtils.SQLITE_DB_NAME);
- saveOptimizationMetrics(c, syntax);
+ saveOptimizationMetrics(jobId, c, syntax);
} catch (SQLException e) {
LOG.error("Error while saving optimization metrics to SQLite: {}", e);
}
@@ -253,32 +268,39 @@ public final class MetricStore {
/**
* Save the job metrics for the optimization to the DB, in the form of LibSVM, to a remote DB, if applicable.
* The metrics are as follows: the JCT (duration), and the IR DAG execution properties.
+ *
+ * @param address Address to the DB.
+ * @param jobId Job ID, of which we record the metrics.
+ * @param dbId the ID of the DB.
+ * @param dbPasswd the Password to the DB.
*/
- public void saveOptimizationMetricsToDB(final String address, final String id, final String passwd) {
+ public void saveOptimizationMetricsToDB(final String address, final String jobId,
+ final String dbId, final String dbPasswd) {
final String[] syntax = {"SERIAL PRIMARY KEY"};
if (!MetricUtils.metaDataLoaded()) {
- saveOptimizationMetricsToLocal();
+ saveOptimizationMetricsToLocal(jobId);
return;
}
- try (Connection c = DriverManager.getConnection(address, id, passwd)) {
+ try (Connection c = DriverManager.getConnection(address, dbId, dbPasswd)) {
LOG.info("Opened database successfully at {}", MetricUtils.POSTGRESQL_METADATA_DB_NAME);
- saveOptimizationMetrics(c, syntax);
+ saveOptimizationMetrics(jobId, c, syntax);
} catch (SQLException e) {
LOG.error("Error while saving optimization metrics to PostgreSQL: {}", e);
LOG.info("Saving metrics on the local SQLite DB");
- saveOptimizationMetricsToLocal();
+ saveOptimizationMetricsToLocal(jobId);
}
}
/**
* Save the job metrics for the optimization to the DB, in the form of LibSVM.
*
+ * @param jobId the ID of the job.
* @param c the connection to the DB.
* @param syntax the db-specific syntax.
*/
- private void saveOptimizationMetrics(final Connection c, final String[] syntax) {
+ private void saveOptimizationMetrics(final String jobId, final Connection c, final String[] syntax) {
try (Statement statement = c.createStatement()) {
statement.setQueryTimeout(30); // set timeout to 30 sec.
@@ -305,17 +327,17 @@ public final class MetricStore {
try {
statement.executeUpdate("CREATE TABLE IF NOT EXISTS " + tableName
- + " (id " + syntax[0] + ", duration INTEGER NOT NULL, inputsize INTEGER NOT NULL, "
+ + " (id " + syntax[0] + ", duration BIGINT NOT NULL, inputsize BIGINT NOT NULL, "
+ "jvmmemsize BIGINT NOT NULL, memsize BIGINT NOT NULL, "
+ "vertex_properties TEXT NOT NULL, edge_properties TEXT NOT NULL, "
- + "created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP);");
+ + "note TEXT, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP);");
LOG.info("CREATED TABLE For {} IF NOT PRESENT", tableName);
statement.executeUpdate("INSERT INTO " + tableName
- + " (duration, inputsize, jvmmemsize, memsize, vertex_properties, edge_properties) "
+ + " (duration, inputsize, jvmmemsize, memsize, vertex_properties, edge_properties, note) "
+ "VALUES (" + duration + ", " + inputSize + ", "
+ jvmMemSize + ", " + memSize + ", '"
- + vertexProperties + "', '" + edgeProperties + "');");
+ + vertexProperties + "', '" + edgeProperties + "', '" + jobId + "');");
LOG.info("Recorded metrics on the table for {}", tableName);
} catch (SQLException e) {
LOG.error("Error while saving optimization metrics: {}", e);