You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2020/04/13 20:08:44 UTC

[systemml] branch master updated: [SYSTEMDS-118] New generic gridSearch builtin function

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git


The following commit(s) were added to refs/heads/master by this push:
     new acfe388  [SYSTEMDS-118] New generic gridSearch builtin function
acfe388 is described below

commit acfe3883a50b827e78db45d0db901a3f448add20
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Mon Apr 13 22:05:52 2020 +0200

    [SYSTEMDS-118] New generic gridSearch builtin function
    
    This patch adds a new generic grid search function for hyper-parameter
    optimization of arbitrary ML algorithms and parameter combinations. This
    function takes train and eval functions by name as well as lists of
    parameter names and vectors of their values, and returns the parameter
    combination and model that gave the best results.
    
    So far hyper-parameter optimization is working, but the core
    training/scoring part needs additional features on list data types
    (e.g., list-list append, and eval fcalls with lists of unnamed and named
    parameters). Also, before it can be applied in practice it needs an
    integration with cross validation.
---
 docs/Tasks.txt                                     |  2 +-
 scripts/builtin/gridSearch.dml                     | 80 +++++++++++++++++++++
 .../java/org/apache/sysds/common/Builtins.java     |  1 +
 .../functions/builtin/BuiltinGridSearchTest.java   | 82 ++++++++++++++++++++++
 .../scripts/functions/builtin/GridSearchLM.dml     | 44 ++++++++++++
 5 files changed, 208 insertions(+), 1 deletion(-)

diff --git a/docs/Tasks.txt b/docs/Tasks.txt
index c4fa46f..5ae71b1 100644
--- a/docs/Tasks.txt
+++ b/docs/Tasks.txt
@@ -91,7 +91,7 @@ SYSTEMDS-110 New Builtin Functions
  * 115 Builtin function for model debugging (slice finder)            OK
  * 116 Builtin function for kmeans                                    OK
  * 117 Builtin function for lm cross validation                       OK
- * 118 Builtin function for hyperparameter grid search with CVlm
+ * 118 Builtin function for hyperparameter grid search                
  * 119 Builtin functions for l2svm and msvm                           OK
 
 SYSTEMDS-120 Performance Features
diff --git a/scripts/builtin/gridSearch.dml b/scripts/builtin/gridSearch.dml
new file mode 100644
index 0000000..227b863
--- /dev/null
+++ b/scripts/builtin/gridSearch.dml
@@ -0,0 +1,80 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Generic grid search over hyper-parameter combinations: materializes the
+# cross product of all given parameter values, calls the train function (by
+# name) for every combination, scores each resulting model with the predict
+# function (by name), and returns the model and parameters with minimum loss.
+#
+# INPUT:
+#   X            feature matrix, passed as X to train and as 1st arg to predict
+#   y            label matrix, passed as y to train and as 2nd arg to predict
+#   train        name of the training function, invoked via eval
+#   predict      name of the scoring function, invoked via eval with
+#                (X, y, model); its result is the loss to be minimized
+#   params       list of hyper-parameter names
+#   paramValues  list of value vectors (column vectors), one per name in params
+#   verbose      if TRUE, print the materialized parameter combinations
+#
+# OUTPUT:
+#   B            model with the smallest loss, as an ncol(X) x 1 column vector
+#   opt          frame with the hyper-parameter values that produced B
+#
+# NOTE(review): assumes train returns an ncol(X) x 1 column vector (the shape
+# needed for X %*% B) - confirm for train functions other than lm.
+m_gridSearch = function(Matrix[Double] X, Matrix[Double] y, String train, String predict,
+  List[String] params, List[Unknown] paramValues, Boolean verbose = TRUE) 
+  return (Matrix[Double] B, Frame[Unknown] opt) 
+{
+  # Step 0) preparation of parameters, lengths, and values in convenient form
+  numParams = length(params);
+  paramLens = matrix(0, numParams, 1);
+  for( j in 1:numParams ) {
+    vect = as.matrix(paramValues[j,1]);
+    paramLens[j,1] = nrow(vect);
+  }
+  # row j holds the values of parameter j, zero-padded to the longest vector
+  paramVals = matrix(0, numParams, max(paramLens));
+  for( j in 1:numParams ) {
+    vect = as.matrix(paramValues[j,1]);
+    paramVals[j,1:nrow(vect)] = t(vect);
+  }
+  # cumLens[j] = product of the lengths of all parameters after j, i.e., the
+  # number of consecutive configurations that share one value of parameter j
+  cumLens = rev(cumprod(rev(paramLens))/rev(paramLens));
+  numConfigs = prod(paramLens);
+  
+  # Step 1) materialize hyper-parameter combinations 
+  # (simplifies debugging; the cost is negligible compared to training)
+  HP = matrix(0, numConfigs, numParams);
+  parfor( i in 1:nrow(HP) ) {
+    for( j in 1:numParams ) {
+      # decode configuration i like a mixed-radix number; the integer division
+      # makes the truncation to a valid column index explicit
+      HP[i,j] = paramVals[j,as.scalar((((i-1) %/% cumLens[j,1]) %% paramLens[j,1]) + 1)];
+    }
+  }
+
+  if( verbose )
+    print("GridSearch: Hyper-parameter combinations: \n"+toString(HP));
+
+  # Step 2) training/scoring of parameter combinations
+  # TODO integrate cross validation (scoring currently reuses the training data)
+  Rbeta = matrix(0, nrow(HP), ncol(X));
+  Rloss = matrix(0, nrow(HP), 1);
+  arguments = list(X=X, y=y);
+
+  parfor( i in 1:nrow(HP) ) {
+    # a) prepare training arguments
+    # NOTE(review): list(key=value) looks like it names the entry literally
+    # "key" instead of using the string held in key - confirm list semantics
+    largs = arguments;
+    for( j in 1:numParams ) {
+      key = as.scalar(params[j]);
+      value = as.scalar(HP[i,j]);
+      largs = append(largs, list(key=value));
+    }
+
+    # b) core training/scoring
+    lbeta = eval(train, largs);
+    lloss = eval(predict, list(X, y, lbeta));
+
+    # c) write models and loss back to output
+    # (models are stored one per row, hence the transpose of the column vector)
+    Rbeta[i,] = t(lbeta);
+    Rloss[i,] = lloss;
+  }
+
+  # Step 3) select best parameter combination
+  ix = as.scalar(rowIndexMin(t(Rloss)));
+  B = t(Rbeta[ix,]);       # optimal model, back in column-vector form
+  opt = as.frame(HP[ix,]); # optimal hyper-parameters
+}
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index 4f20d87..60b7c18 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -90,6 +90,7 @@ public enum Builtins {
 	EVAL("eval", false),
 	FLOOR("floor", false),
 	GNMF("gnmf", true),
+	GRID_SEARCH("gridSearch", true),
 	IFELSE("ifelse", false),
 	IMG_MIRROR("img_mirror", true),
 	IMG_BRIGHTNESS("img_brightness", true),
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGridSearchTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGridSearchTest.java
new file mode 100644
index 0000000..556a7d7
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGridSearchTest.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import org.junit.Assert;
+import org.junit.Ignore;
+import org.junit.Test;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.lops.LopProperties.ExecType;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+
+
+/**
+ * Tests the gridSearch builtin function by running the GridSearchLM.dml
+ * script on randomly generated inputs and checking the boolean it writes.
+ * 
+ * Both tests are currently disabled via {@code @Ignore} (instead of having
+ * empty bodies) so that test reports show them as skipped, not as passed.
+ */
+public class BuiltinGridSearchTest extends AutomatedTestBase
+{
+	private final static String TEST_NAME = "GridSearchLM";
+	private final static String TEST_DIR = "functions/builtin/";
+	private final static String TEST_CLASS_DIR = TEST_DIR + BuiltinGridSearchTest.class.getSimpleName() + "/";
+	
+	//dimensions of the generated feature matrix X (y has the same number of rows)
+	private final static int rows = 300;
+	private final static int cols = 20;
+	
+	@Override
+	public void setUp() {
+		addTestConfiguration(TEST_NAME,new TestConfiguration(TEST_CLASS_DIR, TEST_NAME,new String[]{"R"})); 
+	}
+	
+	@Test
+	@Ignore("TODO additional list features needed")
+	public void testGridSearchCP() {
+		runGridSearch(ExecType.CP);
+	}
+	
+	@Test
+	@Ignore("TODO additional list features needed")
+	public void testGridSearchSpark() {
+		runGridSearch(ExecType.SPARK);
+	}
+	
+	/**
+	 * Runs the test script with the given execution type and asserts that the
+	 * script wrote TRUE to its output R.
+	 * 
+	 * @param et execution type to set for the duration of the run; the
+	 *           previous execution mode is restored afterwards
+	 */
+	private void runGridSearch(ExecType et)
+	{
+		ExecMode modeOld = setExecMode(et);
+		try {
+			loadTestConfiguration(getTestConfiguration(TEST_NAME));
+			String HOME = SCRIPT_DIR + TEST_DIR;
+	
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			programArgs = new String[] {"-args", input("X"), input("y"), output("R")};
+			double[][] X = getRandomMatrix(rows, cols, 0, 1, 0.8, -1);
+			double[][] y = getRandomMatrix(rows, 1, 0, 1, 0.8, -1);
+			writeInputMatrixWithMTD("X", X, true);
+			writeInputMatrixWithMTD("y", y, true);
+			
+			runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+			
+			//expected loss smaller than default invocation
+			Assert.assertTrue(TestUtils.readDMLBoolean(output("R")));
+		}
+		finally {
+			//always restore the execution mode, even if the run failed
+			resetExecMode(modeOld);
+		}
+	}
+}
diff --git a/src/test/scripts/functions/builtin/GridSearchLM.dml b/src/test/scripts/functions/builtin/GridSearchLM.dml
new file mode 100644
index 0000000..9b33713
--- /dev/null
+++ b/src/test/scripts/functions/builtin/GridSearchLM.dml
@@ -0,0 +1,44 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Sum of squared residuals of the linear model B on features X and labels y.
+l2norm = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] B) return (Double loss) {
+  residuals = y - X %*% B;
+  loss = sum(residuals^2);
+}
+
+# Read the feature matrix and the labels from the first two script arguments.
+X = read($1);
+y = read($2);
+
+# Use the first N rows for training and all remaining rows for testing.
+N = 200;
+Xtrain = X[1:N,];
+ytrain = y[1:N,];
+Xtest = X[(N+1):nrow(X),];
+ytest = y[(N+1):nrow(X),];
+
+# Grid over three lm hyper-parameters; each entry of paramRanges holds the
+# candidate values for the parameter name at the same position in params.
+params = list("reg", "tol", "maxi");
+paramRanges = list(10^seq(0,-4), 10^seq(-5,-9), 10^seq(1,3));
+# gridSearch calls the scoring function with (X, y, model) and minimizes its
+# result, so pass l2norm, which has exactly that contract and is also the
+# metric used for the comparison below.
+[B1, opt] = gridSearch(Xtrain, ytrain, "lm", "l2norm", params, paramRanges, TRUE);
+# Baseline: lm without explicitly set hyper-parameters.
+B2 = lm(X=Xtrain, y=ytrain, verbose=FALSE);
+
+# Compare both models on the held-out rows.
+l1 = l2norm(Xtest, ytest, B1);
+l2 = l2norm(Xtest, ytest, B2);
+R = l1 <= l2;
+
+# TRUE if the grid-search model is at least as good as the baseline.
+write(R, $3)