You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this copy.
Posted to commits@systemml.apache.org by de...@apache.org on 2017/02/08 00:38:34 UTC

incubator-systemml git commit: [SYSTEMML-1235] Migrate GNMFTest to new MLContext

Repository: incubator-systemml
Updated Branches:
  refs/heads/master 326c1c00e -> 6158bfaf9


[SYSTEMML-1235] Migrate GNMFTest to new MLContext

Closes #381.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/6158bfaf
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/6158bfaf
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/6158bfaf

Branch: refs/heads/master
Commit: 6158bfaf9079b7a3882e709cbc6d873180c5f373
Parents: 326c1c0
Author: Deron Eriksson <de...@us.ibm.com>
Authored: Tue Feb 7 16:36:51 2017 -0800
Committer: Deron Eriksson <de...@us.ibm.com>
Committed: Tue Feb 7 16:36:51 2017 -0800

----------------------------------------------------------------------
 .../functions/mlcontext/GNMFTest.java           | 90 ++++++++++++--------
 1 file changed, 53 insertions(+), 37 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/6158bfaf/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java
index 89a4363..99ab53b 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java
@@ -26,18 +26,24 @@ import java.util.Collection;
 import java.util.HashMap;
 import java.util.List;
 
-import org.apache.spark.SparkContext;
+import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix;
 import org.apache.spark.mllib.linalg.distributed.MatrixEntry;
+import org.apache.spark.rdd.RDD;
 import org.apache.sysml.api.DMLException;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
-import org.apache.sysml.api.MLContext;
-import org.apache.sysml.api.MLContextProxy;
-import org.apache.sysml.api.MLOutput;
+import org.apache.sysml.api.mlcontext.MLContext;
+import org.apache.sysml.api.mlcontext.MLResults;
+import org.apache.sysml.api.mlcontext.Matrix;
+import org.apache.sysml.api.mlcontext.MatrixFormat;
+import org.apache.sysml.api.mlcontext.MatrixMetadata;
+import org.apache.sysml.api.mlcontext.Script;
+import org.apache.sysml.api.mlcontext.ScriptFactory;
 import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.parser.ParseException;
 import org.apache.sysml.runtime.DMLRuntimeException;
@@ -52,6 +58,7 @@ import org.apache.sysml.runtime.util.MapReduceTool;
 import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.utils.TestUtils;
 import org.junit.Assert;
+import org.junit.BeforeClass;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
@@ -67,12 +74,26 @@ public class GNMFTest extends AutomatedTestBase
 	
 	int numRegisteredInputs;
 	int numRegisteredOutputs;
-	
+
+	private static SparkConf conf;
+	private static JavaSparkContext sc;
+	private static MLContext ml;
+
 	public GNMFTest(int in, int out) {
 		numRegisteredInputs = in;
 		numRegisteredOutputs = out;
 	}
-	
+
+	@BeforeClass
+	public static void setUpClass() {
+		if (conf == null)
+			conf = SparkExecutionContext.createSystemMLSparkConf()
+				.setAppName("GNMFTest").setMaster("local");
+		if (sc == null)
+			sc = new JavaSparkContext(conf);
+		ml = new MLContext(sc);
+	}
+
 	@Parameters
 	 public static Collection<Object[]> data() {
 	   Object[][] data = new Object[][] { { 0, 0 }, { 3, 2 }, { 2, 2 }, { 2, 1 }, { 2, 0 }, { 3, 0 }};
@@ -145,43 +166,46 @@ public class GNMFTest extends AutomatedTestBase
 		DMLScript.USE_LOCAL_SPARK_CONFIG = true;
 		RUNTIME_PLATFORM oldRT = DMLScript.rtplatform;
 		
-		MLContext mlCtx = null;
-		SparkContext sc = null;
 		try 
 		{
 			DMLScript.rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK;
-		
-			mlCtx = getMLContextForTesting();
-			sc = mlCtx.getSparkContext();
-			mlCtx.reset(true); // Cleanup config to ensure future MLContext testcases have correct 'cp.parallel.matrixmult'
-			
+
+			Script script = ScriptFactory.dmlFromFile(fullDMLScriptName);
+			// set positional argument values
+			for (int argNum = 1; argNum <= proArgs.size(); argNum++) {
+				script.in("$" + argNum, proArgs.get(argNum-1));
+			}
+
 			// Read two matrices through RDD and one through HDFS
 			if(numRegisteredInputs >= 1) {
-				JavaRDD<String> vIn = sc.textFile(input("v"), 2).toJavaRDD();
-				mlCtx.registerInput("V", vIn, "text", m, n);
+				JavaRDD<String> vIn = sc.sc().textFile(input("v"), 2).toJavaRDD();
+				MatrixMetadata mm = new MatrixMetadata(MatrixFormat.IJV, m, n);
+				script.in("V", vIn, mm);
 			}
 			
 			if(numRegisteredInputs >= 2) {
-				JavaRDD<String> wIn = sc.textFile(input("w"), 2).toJavaRDD();
-				mlCtx.registerInput("W", wIn, "text", m, k);
+				JavaRDD<String> wIn = sc.sc().textFile(input("w"), 2).toJavaRDD();
+				MatrixMetadata mm = new MatrixMetadata(MatrixFormat.IJV, m, k);
+				script.in("W", wIn, mm);
 			}
 			
 			if(numRegisteredInputs >= 3) {
-				JavaRDD<String> hIn = sc.textFile(input("h"), 2).toJavaRDD();
-				mlCtx.registerInput("H", hIn, "text", k, n);
+				JavaRDD<String> hIn = sc.sc().textFile(input("h"), 2).toJavaRDD();
+				MatrixMetadata mm = new MatrixMetadata(MatrixFormat.IJV, k, n);
+				script.in("H", hIn, mm);
 			}
 			
 			// Output one matrix to HDFS and get one as RDD
 			if(numRegisteredOutputs >= 1) {
-				mlCtx.registerOutput("H");
+				script.out("H");
 			}
 			
 			if(numRegisteredOutputs >= 2) {
-				mlCtx.registerOutput("W");
-				mlCtx.setConfig("cp.parallel.matrixmult", "false");
+				script.out("W");
+				ml.setConfigProperty("cp.parallel.matrixmult", "false");
 			}
 			
-			MLOutput out = mlCtx.execute(fullDMLScriptName, programArgs);
+			MLResults results = ml.execute(script);
 			
 			if(numRegisteredOutputs >= 2) {
 				String configStr = ConfigurationManager.getDMLConfig().getConfigInfo();
@@ -190,7 +214,7 @@ public class GNMFTest extends AutomatedTestBase
 			}
 			
 			if(numRegisteredOutputs >= 1) {
-				JavaRDD<String> hOut = out.getStringRDD("H", "text");
+				RDD<String> hOut = results.getRDDStringIJV("H");
 				String fName = output("h");
 				try {
 					MapReduceTool.deleteFileIfExistOnHDFS( fName );
@@ -201,10 +225,11 @@ public class GNMFTest extends AutomatedTestBase
 			}
 			
 			if(numRegisteredOutputs >= 2) {
-//				Test converter: Text -> CoordinateMatrix -> BinaryBlock -> Text
-//				JavaRDD<String> wOut = out.getStringRDD("W", "text");
-				JavaRDD<MatrixEntry> matRDD = out.getStringRDD("W", "text").map(new StringToMatrixEntry());
-				MatrixCharacteristics mcW = out.getMatrixCharacteristics("W");
+				JavaRDD<String> javaRDDStringIJV = results.getJavaRDDStringIJV("W");
+				JavaRDD<MatrixEntry> matRDD = javaRDDStringIJV.map(new StringToMatrixEntry());
+				Matrix matrix = results.getMatrix("W");
+				MatrixCharacteristics mcW = matrix.getMatrixMetadata().asMatrixCharacteristics();
+
 				CoordinateMatrix coordinateMatrix = new CoordinateMatrix(matRDD.rdd(), mcW.getRows(), mcW.getCols());
 				JavaPairRDD<MatrixIndexes, MatrixBlock> binaryRDD = RDDConverterUtilsExt.coordinateMatrixToBinaryBlock(sc, coordinateMatrix, mcW, true);
 				JavaRDD<String> wOut = RDDConverterUtils.binaryBlockToTextCell(binaryRDD, mcW);
@@ -227,19 +252,10 @@ public class GNMFTest extends AutomatedTestBase
 			HashMap<CellIndex, Double> hmHR = readRMatrixFromFS("h");
 			TestUtils.compareMatrices(hmWDML, hmWR, 0.000001, "hmWDML", "hmWR");
 			TestUtils.compareMatrices(hmHDML, hmHR, 0.000001, "hmHDML", "hmHR");
-			
-			//cleanup mlcontext (prevent test memory leaks)
-			mlCtx.reset();
 		}
 		finally {
 			DMLScript.rtplatform = oldRT;
 			DMLScript.USE_LOCAL_SPARK_CONFIG = oldConfig;
-
-			if (sc != null) {
-				sc.stop();
-			}
-			SparkExecutionContext.resetSparkContextStatic();
-			MLContextProxy.setActive(false);
 		}
 	}