You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ja...@apache.org on 2020/07/21 15:56:13 UTC
[systemds] branch master updated: [SYSTEMDS-1863] Full MLContext test for LinearReg
This is an automated email from the ASF dual-hosted git repository.
janardhan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new d1a1492 [SYSTEMDS-1863] Full MLContext test for LinearReg
d1a1492 is described below
commit d1a1492c2da608f7be0a5458beaadabb44b06c2b
Author: Janardhan Pulivarthi <j1...@protonmail.com>
AuthorDate: Mon Jul 20 12:18:57 2020 +0530
[SYSTEMDS-1863] Full MLContext test for LinearReg
* Reuses the existing R algorithm scripts from the codegen tests
as the reference implementation, and compares the MLContext
result against the R output.
* Improves testing by allowing all the necessary inputs to be
provided to the script.
---
.../org/apache/sysds/test/AutomatedTestBase.java | 2 ++
.../functions/mlcontext/MLContextLinregTest.java | 38 +++++++++++++++++++---
2 files changed, 35 insertions(+), 5 deletions(-)
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index a66ee1e..0183e34 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -1649,6 +1649,8 @@ public abstract class AutomatedTestBase {
}
protected String getRScript() {
+ if(fullRScriptName != null)
+ return fullRScriptName;
return sourceDirectory + selectedTest + ".R";
}
diff --git a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextLinregTest.java b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextLinregTest.java
index a5cddb8..0e45cb4 100644
--- a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextLinregTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextLinregTest.java
@@ -22,8 +22,13 @@ package org.apache.sysds.test.functions.mlcontext;
import static org.apache.sysds.api.mlcontext.ScriptFactory.dmlFromFile;
import org.apache.log4j.Logger;
-import org.junit.Test;
import org.apache.sysds.api.mlcontext.Script;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.util.HashMap;
public class MLContextLinregTest extends MLContextTestBase {
protected static Logger log = Logger.getLogger(MLContextLinregTest.class);
@@ -37,6 +42,11 @@ public class MLContextLinregTest extends MLContextTestBase {
CG, DS,
}
+ private final static double eps = 1e-3;
+
+ private final static int rows = 2468;
+ private final static int cols = 507;
+
@Test
public void testLinregCGSparse() {
runLinregTestMLC(LinregType.CG, true);
@@ -59,24 +69,42 @@ public class MLContextLinregTest extends MLContextTestBase {
private void runLinregTestMLC(LinregType type, boolean sparse) {
- double[][] X = getRandomMatrix(10, 3, 0, 1, sparse ? sparsity2 : sparsity1, 7);
- double[][] Y = getRandomMatrix(10, 1, 0, 10, 1.0, 3);
+ double[][] X = getRandomMatrix(rows, cols, 0, 1, sparse ? sparsity2 : sparsity1, 7);
+ double[][] Y = getRandomMatrix(rows, 1, 0, 10, 1.0, 3);
+
+ // Hack Alert
+ // overwrite baseDirectory to the place where test data is stored.
+ baseDirectory = "target/testTemp/functions/mlcontext/";
+
+ fullRScriptName = "src/test/scripts/functions/codegenalg/Algorithm_LinregCG.R";
+
+ writeInputMatrixWithMTD("X", X, true);
+ writeInputMatrixWithMTD("y", Y, true);
+
+ rCmd = getRCmd(inputDir(), "0", "0.000001", "0", "0.001", expectedDir());
+ runRScript(true);
+
+ MatrixBlock outmat = new MatrixBlock();
switch (type) {
case CG:
Script lrcg = dmlFromFile(TEST_SCRIPT_CG);
lrcg.in("X", X).in("y", Y).in("$icpt", "0").in("$tol", "0.000001").in("$maxi", "0").in("$reg", "0.000001")
.out("beta_out");
- ml.execute(lrcg);
+ outmat = ml.execute(lrcg).getMatrix("beta_out").toMatrixBlock();
break;
case DS:
Script lrds = dmlFromFile(TEST_SCRIPT_DS);
lrds.in("X", X).in("y", Y).in("$icpt", "0").in("$reg", "0.000001").out("beta_out");
- ml.execute(lrds);
+ outmat = ml.execute(lrds).getMatrix("beta_out").toMatrixBlock();
break;
}
+
+ //compare matrices
+ HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromFS("w");
+ TestUtils.compareMatrices(rfile, outmat, eps);
}
}