You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/05/27 19:04:45 UTC

[systemds] branch master updated: [SYSTEMDS-2983] Fix missing generality gridSearch builtin primitive

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 16f4191  [SYSTEMDS-2983] Fix missing generality gridSearch builtin primitive
16f4191 is described below

commit 16f4191b1738565d66ef394948571003911d1b1c
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Thu May 27 21:03:30 2021 +0200

    [SYSTEMDS-2983] Fix missing generality gridSearch builtin primitive
    
    So far gridSearch only supported train functions with a fixed set of
    parameters equivalent to lm, and did not correctly handle different
    intercept values (which result in different model sizes). This patch
    generalizes the existing primitive and adds the necessary tests for
    gridSearch with multiLogReg.
    
    This rework also revealed performance issues of task-parallel grid search
    due to contention on shared recompilation DAGs in the context of eval
    function calls. These issues will be addressed separately.
---
 scripts/builtin/gridSearch.dml                     | 22 ++++++++--------
 .../controlprogram/FunctionProgramBlock.java       |  3 ++-
 .../functions/builtin/BuiltinGridSearchTest.java   | 29 ++++++++++++++--------
 .../scripts/functions/builtin/GridSearchLM.dml     |  4 ++-
 .../{GridSearchLM.dml => GridSearchMLogreg.dml}    | 20 ++++++++-------
 .../scripts/functions/builtin/HyperbandLM3.dml     |  5 ++--
 6 files changed, 49 insertions(+), 34 deletions(-)

diff --git a/scripts/builtin/gridSearch.dml b/scripts/builtin/gridSearch.dml
index ca37745..df4d5be 100644
--- a/scripts/builtin/gridSearch.dml
+++ b/scripts/builtin/gridSearch.dml
@@ -20,8 +20,10 @@
 #-------------------------------------------------------------
 
 m_gridSearch = function(Matrix[Double] X, Matrix[Double] y, String train, String predict,
-  List[String] params, List[Unknown] paramValues, Boolean verbose = TRUE) 
-  return (Matrix[Double] B, Frame[Unknown] opt) 
+    Integer ncolB=ncol(X), List[String] params, List[Unknown] paramValues, List[Unknown]
+    trainArgs = list(X=X, y=y, icpt=0, reg=-1, tol=-1, maxi=-1, verbose=FALSE),
+    Boolean verbose = TRUE) 
+  return (Matrix[Double] B, Frame[Unknown] opt)
 {
   # Step 0) preparation of parameters, lengths, and values in convenient form
   numParams = length(params);
@@ -38,7 +40,7 @@ m_gridSearch = function(Matrix[Double] X, Matrix[Double] y, String train, String
   cumLens = rev(cumprod(rev(paramLens))/rev(paramLens));
   numConfigs = prod(paramLens);
   
-  # Step 1) materialize hyper-parameter combinations 
+  # Step 1) materialize hyper-parameter combinations
   # (simplify debugging and compared to compute negligible)
   HP = matrix(0, numConfigs, numParams);
   parfor( i in 1:nrow(HP) ) {
@@ -46,24 +48,24 @@ m_gridSearch = function(Matrix[Double] X, Matrix[Double] y, String train, String
       HP[i,j] = paramVals[j,as.scalar(((i-1)/cumLens[j,1])%%paramLens[j,1]+1)];
   }
 
-  if( verbose )
+  if( verbose ) {
+    print("GridSeach: Number of hyper-parameters: \n"+toString(paramLens));
     print("GridSeach: Hyper-parameter combinations: \n"+toString(HP));
+  }
 
   # Step 2) training/scoring of parameter combinations
   # TODO integrate cross validation
-  Rbeta = matrix(0, nrow(HP), ncol(X));
+  Rbeta = matrix(0, nrow(HP), ncolB);
   Rloss = matrix(0, nrow(HP), 1);
-  # TODO pass arguments for function call from outside
-  arguments = list(X=X, y=y, icpt=0, reg=-1, tol=-1, maxi=-1, verbose=FALSE);
 
   parfor( i in 1:nrow(HP) ) {
     # a) replace training arguments
-    largs = arguments;
+    largs = trainArgs;
     for( j in 1:numParams )
       largs[as.scalar(params[j])] = as.scalar(HP[i,j]);
     # b) core training/scoring and write-back
-    # TODO investigate rmvar handling with explicit binding (lbeta)
-    Rbeta[i,] = t(eval(train, largs));
+    lbeta = t(eval(train, largs))
+    Rbeta[i,1:ncol(lbeta)] = lbeta;
     Rloss[i,] = eval(predict, list(X, y, t(Rbeta[i,])));
   }
 
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
index 32ba97f..b23c784 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
@@ -27,6 +27,7 @@ import java.util.stream.Collectors;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.common.Types.FunctionBlock;
+import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.hops.recompile.Recompiler;
 import org.apache.sysds.hops.recompile.Recompiler.ResetType;
@@ -155,7 +156,7 @@ public class FunctionProgramBlock extends ProgramBlock implements FunctionBlock
 				LOG.error("Function output "+ varName +" is missing.");
 			else if( dat.getDataType() != diOut.getDataType() )
 				LOG.warn("Function output "+ varName +" has wrong data type: "+dat.getDataType()+".");
-			else if( dat.getValueType() != diOut.getValueType() )
+			else if( diOut.getValueType() != ValueType.UNKNOWN && dat.getValueType() != diOut.getValueType() )
 				LOG.warn("Function output "+ varName +" has wrong value type: "+dat.getValueType()+".");
 		}
 	}
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGridSearchTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGridSearchTest.java
index d5ec4f9..7d4449b 100644
--- a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGridSearchTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGridSearchTest.java
@@ -31,7 +31,8 @@ import org.apache.sysds.test.TestUtils;
 
 public class BuiltinGridSearchTest extends AutomatedTestBase
 {
-	private final static String TEST_NAME = "GridSearchLM";
+	private final static String TEST_NAME1 = "GridSearchLM";
+	private final static String TEST_NAME2 = "GridSearchMLogreg";
 	private final static String TEST_DIR = "functions/builtin/";
 	private final static String TEST_CLASS_DIR = TEST_DIR + BuiltinGridSearchTest.class.getSimpleName() + "/";
 	
@@ -40,30 +41,36 @@ public class BuiltinGridSearchTest extends AutomatedTestBase
 	
 	@Override
 	public void setUp() {
-		addTestConfiguration(TEST_NAME,new TestConfiguration(TEST_CLASS_DIR, TEST_NAME,new String[]{"R"})); 
+		addTestConfiguration(TEST_NAME1,new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1,new String[]{"R"}));
+		addTestConfiguration(TEST_NAME2,new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2,new String[]{"R"}));
 	}
 	
 	@Test
-	public void testGridSearchCP() {
-		runGridSearch(ExecType.CP);
+	public void testGridSearchLmCP() {
+		runGridSearch(TEST_NAME1, ExecType.CP);
 	}
 	
 	@Test
-	public void testGridSearchSpark() {
-		runGridSearch(ExecType.SPARK);
+	public void testGridSearchLmSpark() {
+		runGridSearch(TEST_NAME1, ExecType.SPARK);
 	}
 	
-	private void runGridSearch(ExecType et)
+	@Test
+	public void testGridSearchMLogregCP() {
+		runGridSearch(TEST_NAME2, ExecType.CP);
+	}
+	
+	private void runGridSearch(String testname, ExecType et)
 	{
 		ExecMode modeOld = setExecMode(et);
 		try {
-			loadTestConfiguration(getTestConfiguration(TEST_NAME));
+			loadTestConfiguration(getTestConfiguration(testname));
 			String HOME = SCRIPT_DIR + TEST_DIR;
 	
-			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			fullDMLScriptName = HOME + testname + ".dml";
 			programArgs = new String[] {"-args", input("X"), input("y"), output("R")};
-			double[][] X = getRandomMatrix(rows, cols, 0, 1, 0.8, -1);
-			double[][] y = getRandomMatrix(rows, 1, 0, 1, 0.8, -1);
+			double[][] X = getRandomMatrix(rows, cols, 0, 1, 0.8, 7);
+			double[][] y = getRandomMatrix(rows, 1, 1, 2, 1, 1);
 			writeInputMatrixWithMTD("X", X, true);
 			writeInputMatrixWithMTD("y", y, true);
 			
diff --git a/src/test/scripts/functions/builtin/GridSearchLM.dml b/src/test/scripts/functions/builtin/GridSearchLM.dml
index 2d7cd3a..f6ec084 100644
--- a/src/test/scripts/functions/builtin/GridSearchLM.dml
+++ b/src/test/scripts/functions/builtin/GridSearchLM.dml
@@ -34,7 +34,9 @@ ytest = y[(N+1):nrow(X),];
 
 params = list("reg", "tol", "maxi");
 paramRanges = list(10^seq(0,-4), 10^seq(-6,-12), 10^seq(1,3));
-[B1, opt] = gridSearch(Xtrain, ytrain, "lm", "l2norm", params, paramRanges, TRUE);
+trainArgs = list(X=X, y=y, icpt=0, reg=-1, tol=-1, maxi=-1, verbose=FALSE);
+[B1, opt] = gridSearch(X=Xtrain, y=ytrain, train="lm", predict="l2norm", 
+  ncolB=ncol(X), params=params, paramValues=paramRanges, trainArgs=trainArgs);
 B2 = lm(X=Xtrain, y=ytrain, verbose=FALSE);
 
 l1 = l2norm(Xtest, ytest, B1);
diff --git a/src/test/scripts/functions/builtin/GridSearchLM.dml b/src/test/scripts/functions/builtin/GridSearchMLogreg.dml
similarity index 62%
copy from src/test/scripts/functions/builtin/GridSearchLM.dml
copy to src/test/scripts/functions/builtin/GridSearchMLogreg.dml
index 2d7cd3a..ac96fff 100644
--- a/src/test/scripts/functions/builtin/GridSearchLM.dml
+++ b/src/test/scripts/functions/builtin/GridSearchMLogreg.dml
@@ -19,12 +19,13 @@
 #
 #-------------------------------------------------------------
 
-l2norm = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] B) return (Matrix[Double] loss) {
-  loss = as.matrix(sum((y - X%*%B)^2));
+accuracy = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] B) return (Matrix[Double] err) {
+  [M,yhat,acc] = multiLogRegPredict(X=X, B=B, Y=y, verbose=TRUE);
+  err = as.matrix(1-acc);
 }
 
 X = read($1);
-y = read($2);
+y = round(read($2));
 
 N = 200;
 Xtrain = X[1:N,];
@@ -32,13 +33,14 @@ ytrain = y[1:N,];
 Xtest = X[(N+1):nrow(X),];
 ytest = y[(N+1):nrow(X),];
 
-params = list("reg", "tol", "maxi");
-paramRanges = list(10^seq(0,-4), 10^seq(-6,-12), 10^seq(1,3));
-[B1, opt] = gridSearch(Xtrain, ytrain, "lm", "l2norm", params, paramRanges, TRUE);
-B2 = lm(X=Xtrain, y=ytrain, verbose=FALSE);
+params = list("icpt", "reg", "maxii");
+paramRanges = list(seq(0,2),10^seq(1,-6), 10^seq(1,3));
+trainArgs = list(X=Xtrain, Y=ytrain, icpt=-1, reg=-1, tol=1e-9, maxi=100, maxii=-1, verbose=FALSE);
+[B1,opt] = gridSearch(Xtrain, ytrain, "multiLogReg", "accuracy", ncol(X)+1, params, paramRanges, trainArgs, TRUE);
+B2 = multiLogReg(X=Xtrain, Y=ytrain, verbose=TRUE);
 
-l1 = l2norm(Xtest, ytest, B1);
-l2 = l2norm(Xtest, ytest, B2);
+l1 = accuracy(Xtest, ytest, B1);
+l2 = accuracy(Xtest, ytest, B2);
 R = as.scalar(l1 < l2);
 
 write(R, $3)
diff --git a/src/test/scripts/functions/builtin/HyperbandLM3.dml b/src/test/scripts/functions/builtin/HyperbandLM3.dml
index 9dbdbba..e2b23fa 100644
--- a/src/test/scripts/functions/builtin/HyperbandLM3.dml
+++ b/src/test/scripts/functions/builtin/HyperbandLM3.dml
@@ -42,8 +42,9 @@ paramRanges = matrix("0 20", rows=1, cols=2);
   X_val=X_val, y_val=y_val, params=params, paramRanges=paramRanges);
 
 paramRanges2 = list(10^seq(0,-4))
-[bestWeights, optHyperParams2] = gridSearch(X=X_train, y=y_train,
-  train="lm", predict="l2norm", params=params, paramValues=paramRanges2);
+trainArgs = list(X=X_train, y=y_train, icpt=0, reg=-1, tol=1e-9, maxi=0, verbose=FALSE);
+[bestWeights, optHyperParams2] = gridSearch(X=X_train, y=y_train, ncolB=ncol(X),
+  train="lm", predict="l2norm", trainArgs=trainArgs, params=params, paramValues=paramRanges2);
 
 print(toString(optHyperParams))
 print(toString(optHyperParams2))