You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this plain-text copy.)
Posted to commits@systemds.apache.org by ja...@apache.org on 2020/07/21 15:56:13 UTC

[systemds] branch master updated: [SYSTEMDS-1863] Full MLContext test for LinearReg

This is an automated email from the ASF dual-hosted git repository.

janardhan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new d1a1492  [SYSTEMDS-1863] Full MLContext test for LinearReg
d1a1492 is described below

commit d1a1492c2da608f7be0a5458beaadabb44b06c2b
Author: Janardhan Pulivarthi <j1...@protonmail.com>
AuthorDate: Mon Jul 20 12:18:57 2020 +0530

    [SYSTEMDS-1863] Full MLContext test for LinearReg
    
      * Reuses the existing R algorithm scripts from the codegen
        tests to compute reference results for comparison.
      * This improves the test by letting it supply all of the
        script's necessary inputs.
---
 .../org/apache/sysds/test/AutomatedTestBase.java   |  2 ++
 .../functions/mlcontext/MLContextLinregTest.java   | 38 +++++++++++++++++++---
 2 files changed, 35 insertions(+), 5 deletions(-)

diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index a66ee1e..0183e34 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -1649,6 +1649,8 @@ public abstract class AutomatedTestBase {
 	}
 
 	protected String getRScript() {
+		if(fullRScriptName != null)
+			return fullRScriptName;
 		return sourceDirectory + selectedTest + ".R";
 	}
 
diff --git a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextLinregTest.java b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextLinregTest.java
index a5cddb8..0e45cb4 100644
--- a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextLinregTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextLinregTest.java
@@ -22,8 +22,13 @@ package org.apache.sysds.test.functions.mlcontext;
 import static org.apache.sysds.api.mlcontext.ScriptFactory.dmlFromFile;
 
 import org.apache.log4j.Logger;
-import org.junit.Test;
 import org.apache.sysds.api.mlcontext.Script;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.util.HashMap;
 
 public class MLContextLinregTest extends MLContextTestBase {
 	protected static Logger log = Logger.getLogger(MLContextLinregTest.class);
@@ -37,6 +42,11 @@ public class MLContextLinregTest extends MLContextTestBase {
 		CG, DS,
 	}
 
+	private final static double eps = 1e-3;
+
+	private final static int rows = 2468;
+	private final static int cols = 507;
+
 	@Test
 	public void testLinregCGSparse() {
 		runLinregTestMLC(LinregType.CG, true);
@@ -59,24 +69,42 @@ public class MLContextLinregTest extends MLContextTestBase {
 
 	private void runLinregTestMLC(LinregType type, boolean sparse) {
 
-		double[][] X = getRandomMatrix(10, 3, 0, 1, sparse ? sparsity2 : sparsity1, 7);
-		double[][] Y = getRandomMatrix(10, 1, 0, 10, 1.0, 3);
+		double[][] X = getRandomMatrix(rows, cols, 0, 1, sparse ? sparsity2 : sparsity1, 7);
+		double[][] Y = getRandomMatrix(rows, 1, 0, 10, 1.0, 3);
+
+		// Hack Alert
+		// overwrite baseDirectory to the place where test data is stored.
+		baseDirectory = "target/testTemp/functions/mlcontext/";
+
+		fullRScriptName = "src/test/scripts/functions/codegenalg/Algorithm_LinregCG.R";
+
+		writeInputMatrixWithMTD("X", X, true);
+		writeInputMatrixWithMTD("y", Y, true);
+
+		rCmd = getRCmd(inputDir(), "0", "0.000001", "0", "0.001", expectedDir());
+		runRScript(true);
+
+		MatrixBlock outmat = new MatrixBlock();
 
 		switch (type) {
 		case CG:
 			Script lrcg = dmlFromFile(TEST_SCRIPT_CG);
 			lrcg.in("X", X).in("y", Y).in("$icpt", "0").in("$tol", "0.000001").in("$maxi", "0").in("$reg", "0.000001")
 					.out("beta_out");
-			ml.execute(lrcg);
+			outmat = ml.execute(lrcg).getMatrix("beta_out").toMatrixBlock();
 
 			break;
 
 		case DS:
 			Script lrds = dmlFromFile(TEST_SCRIPT_DS);
 			lrds.in("X", X).in("y", Y).in("$icpt", "0").in("$reg", "0.000001").out("beta_out");
-			ml.execute(lrds);
+			outmat = ml.execute(lrds).getMatrix("beta_out").toMatrixBlock();
 
 			break;
 		}
+
+		//compare matrices
+		HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromFS("w");
+		TestUtils.compareMatrices(rfile, outmat, eps);
 	}
 }