You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by de...@apache.org on 2017/02/08 00:38:34 UTC
incubator-systemml git commit: [SYSTEMML-1235] Migrate GNMFTest to new MLContext
Repository: incubator-systemml
Updated Branches:
refs/heads/master 326c1c00e -> 6158bfaf9
[SYSTEMML-1235] Migrate GNMFTest to new MLContext
Closes #381.
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/6158bfaf
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/6158bfaf
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/6158bfaf
Branch: refs/heads/master
Commit: 6158bfaf9079b7a3882e709cbc6d873180c5f373
Parents: 326c1c0
Author: Deron Eriksson <de...@us.ibm.com>
Authored: Tue Feb 7 16:36:51 2017 -0800
Committer: Deron Eriksson <de...@us.ibm.com>
Committed: Tue Feb 7 16:36:51 2017 -0800
----------------------------------------------------------------------
.../functions/mlcontext/GNMFTest.java | 90 ++++++++++++--------
1 file changed, 53 insertions(+), 37 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/6158bfaf/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java
index 89a4363..99ab53b 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java
@@ -26,18 +26,24 @@ import java.util.Collection;
import java.util.HashMap;
import java.util.List;
-import org.apache.spark.SparkContext;
+import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix;
import org.apache.spark.mllib.linalg.distributed.MatrixEntry;
+import org.apache.spark.rdd.RDD;
import org.apache.sysml.api.DMLException;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
-import org.apache.sysml.api.MLContext;
-import org.apache.sysml.api.MLContextProxy;
-import org.apache.sysml.api.MLOutput;
+import org.apache.sysml.api.mlcontext.MLContext;
+import org.apache.sysml.api.mlcontext.MLResults;
+import org.apache.sysml.api.mlcontext.Matrix;
+import org.apache.sysml.api.mlcontext.MatrixFormat;
+import org.apache.sysml.api.mlcontext.MatrixMetadata;
+import org.apache.sysml.api.mlcontext.Script;
+import org.apache.sysml.api.mlcontext.ScriptFactory;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.parser.ParseException;
import org.apache.sysml.runtime.DMLRuntimeException;
@@ -52,6 +58,7 @@ import org.apache.sysml.runtime.util.MapReduceTool;
import org.apache.sysml.test.integration.AutomatedTestBase;
import org.apache.sysml.test.utils.TestUtils;
import org.junit.Assert;
+import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -67,12 +74,26 @@ public class GNMFTest extends AutomatedTestBase
int numRegisteredInputs;
int numRegisteredOutputs;
-
+
+ private static SparkConf conf;
+ private static JavaSparkContext sc;
+ private static MLContext ml;
+
public GNMFTest(int in, int out) {
numRegisteredInputs = in;
numRegisteredOutputs = out;
}
-
+
+ @BeforeClass
+ public static void setUpClass() {
+ if (conf == null)
+ conf = SparkExecutionContext.createSystemMLSparkConf()
+ .setAppName("GNMFTest").setMaster("local");
+ if (sc == null)
+ sc = new JavaSparkContext(conf);
+ ml = new MLContext(sc);
+ }
+
@Parameters
public static Collection<Object[]> data() {
Object[][] data = new Object[][] { { 0, 0 }, { 3, 2 }, { 2, 2 }, { 2, 1 }, { 2, 0 }, { 3, 0 }};
@@ -145,43 +166,46 @@ public class GNMFTest extends AutomatedTestBase
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
RUNTIME_PLATFORM oldRT = DMLScript.rtplatform;
- MLContext mlCtx = null;
- SparkContext sc = null;
try
{
DMLScript.rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK;
-
- mlCtx = getMLContextForTesting();
- sc = mlCtx.getSparkContext();
- mlCtx.reset(true); // Cleanup config to ensure future MLContext testcases have correct 'cp.parallel.matrixmult'
-
+
+ Script script = ScriptFactory.dmlFromFile(fullDMLScriptName);
+ // set positional argument values
+ for (int argNum = 1; argNum <= proArgs.size(); argNum++) {
+ script.in("$" + argNum, proArgs.get(argNum-1));
+ }
+
// Read two matrices through RDD and one through HDFS
if(numRegisteredInputs >= 1) {
- JavaRDD<String> vIn = sc.textFile(input("v"), 2).toJavaRDD();
- mlCtx.registerInput("V", vIn, "text", m, n);
+ JavaRDD<String> vIn = sc.sc().textFile(input("v"), 2).toJavaRDD();
+ MatrixMetadata mm = new MatrixMetadata(MatrixFormat.IJV, m, n);
+ script.in("V", vIn, mm);
}
if(numRegisteredInputs >= 2) {
- JavaRDD<String> wIn = sc.textFile(input("w"), 2).toJavaRDD();
- mlCtx.registerInput("W", wIn, "text", m, k);
+ JavaRDD<String> wIn = sc.sc().textFile(input("w"), 2).toJavaRDD();
+ MatrixMetadata mm = new MatrixMetadata(MatrixFormat.IJV, m, k);
+ script.in("W", wIn, mm);
}
if(numRegisteredInputs >= 3) {
- JavaRDD<String> hIn = sc.textFile(input("h"), 2).toJavaRDD();
- mlCtx.registerInput("H", hIn, "text", k, n);
+ JavaRDD<String> hIn = sc.sc().textFile(input("h"), 2).toJavaRDD();
+ MatrixMetadata mm = new MatrixMetadata(MatrixFormat.IJV, k, n);
+ script.in("H", hIn, mm);
}
// Output one matrix to HDFS and get one as RDD
if(numRegisteredOutputs >= 1) {
- mlCtx.registerOutput("H");
+ script.out("H");
}
if(numRegisteredOutputs >= 2) {
- mlCtx.registerOutput("W");
- mlCtx.setConfig("cp.parallel.matrixmult", "false");
+ script.out("W");
+ ml.setConfigProperty("cp.parallel.matrixmult", "false");
}
- MLOutput out = mlCtx.execute(fullDMLScriptName, programArgs);
+ MLResults results = ml.execute(script);
if(numRegisteredOutputs >= 2) {
String configStr = ConfigurationManager.getDMLConfig().getConfigInfo();
@@ -190,7 +214,7 @@ public class GNMFTest extends AutomatedTestBase
}
if(numRegisteredOutputs >= 1) {
- JavaRDD<String> hOut = out.getStringRDD("H", "text");
+ RDD<String> hOut = results.getRDDStringIJV("H");
String fName = output("h");
try {
MapReduceTool.deleteFileIfExistOnHDFS( fName );
@@ -201,10 +225,11 @@ public class GNMFTest extends AutomatedTestBase
}
if(numRegisteredOutputs >= 2) {
-// Test converter: Text -> CoordinateMatrix -> BinaryBlock -> Text
-// JavaRDD<String> wOut = out.getStringRDD("W", "text");
- JavaRDD<MatrixEntry> matRDD = out.getStringRDD("W", "text").map(new StringToMatrixEntry());
- MatrixCharacteristics mcW = out.getMatrixCharacteristics("W");
+ JavaRDD<String> javaRDDStringIJV = results.getJavaRDDStringIJV("W");
+ JavaRDD<MatrixEntry> matRDD = javaRDDStringIJV.map(new StringToMatrixEntry());
+ Matrix matrix = results.getMatrix("W");
+ MatrixCharacteristics mcW = matrix.getMatrixMetadata().asMatrixCharacteristics();
+
CoordinateMatrix coordinateMatrix = new CoordinateMatrix(matRDD.rdd(), mcW.getRows(), mcW.getCols());
JavaPairRDD<MatrixIndexes, MatrixBlock> binaryRDD = RDDConverterUtilsExt.coordinateMatrixToBinaryBlock(sc, coordinateMatrix, mcW, true);
JavaRDD<String> wOut = RDDConverterUtils.binaryBlockToTextCell(binaryRDD, mcW);
@@ -227,19 +252,10 @@ public class GNMFTest extends AutomatedTestBase
HashMap<CellIndex, Double> hmHR = readRMatrixFromFS("h");
TestUtils.compareMatrices(hmWDML, hmWR, 0.000001, "hmWDML", "hmWR");
TestUtils.compareMatrices(hmHDML, hmHR, 0.000001, "hmHDML", "hmHR");
-
- //cleanup mlcontext (prevent test memory leaks)
- mlCtx.reset();
}
finally {
DMLScript.rtplatform = oldRT;
DMLScript.USE_LOCAL_SPARK_CONFIG = oldConfig;
-
- if (sc != null) {
- sc.stop();
- }
- SparkExecutionContext.resetSparkContextStatic();
- MLContextProxy.setActive(false);
}
}