Posted to commits@systemds.apache.org by mb...@apache.org on 2020/07/20 17:22:09 UTC
[systemds] branch master updated: [SYSTEMDS-2573] Hyperband built-in function (hyper-param optimization)
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 586e910 [SYSTEMDS-2573] Hyperband built-in function (hyper-param optimization)
586e910 is described below
commit 586e910a1f6d728ab2fc80ea716cf0545946bb19
Author: OsChri <os...@yahoo.de>
AuthorDate: Mon Jul 20 18:51:36 2020 +0200
[SYSTEMDS-2573] Hyperband built-in function (hyper-param optimization)
AMLS project SS2020.
Closes #996.
---
dev/{Tasks.txt => Tasks-obsolete.txt} | 1 +
docs/site/builtins-reference.md | 50 +++++++
scripts/builtin/hyperband.dml | 162 +++++++++++++++++++++
.../java/org/apache/sysds/common/Builtins.java | 1 +
.../functions/builtin/BuiltinHyperbandTest.java | 77 ++++++++++
src/test/scripts/functions/builtin/HyperbandLM.dml | 63 ++++++++
6 files changed, 354 insertions(+)
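Note on the bracket schedule: the nested loops in scripts/builtin/hyperband.dml determine how many candidates (n_i) are trained for how many iterations (r_i) in each round of each bracket. The following is a minimal sketch in Python, not DML, that mirrors only that arithmetic so the schedule can be inspected without running SystemDS. The function name `hyperband_schedule` is hypothetical, and s_max is computed with integer arithmetic rather than `floor(log(R,eta))` as in the commit, on the assumption that both give the same result for integer inputs.

```python
def hyperband_schedule(R=81, eta=3):
    """Return {s: [(n_i, r_i), ...]} mirroring the loops in hyperband.dml."""
    # s_max = floor(log_eta(R)), computed with integers to avoid float rounding
    s_max = 0
    while eta ** (s_max + 1) <= R:
        s_max += 1

    schedule = {}
    for s in range(s_max, -1, -1):
        # n = ceil(floor(B/R/(s+1)) * eta^s) with B = (s_max+1) * R
        n = ((s_max + 1) // (s + 1)) * eta ** s
        rounds = []
        for i in range(s + 1):
            n_i = n // eta ** i               # floor(n * eta^-i)
            r_i = (R * eta ** i) // eta ** s  # floor(R * eta^-s * eta^i)
            rounds.append((n_i, r_i))
        schedule[s] = rounds
    return schedule


if __name__ == "__main__":
    for s, rounds in hyperband_schedule().items():
        print(f"s={s}: {rounds}")
```

With the defaults R = 81 and eta = 3, the widest bracket (s = 4) starts with 81 candidates at 1 iteration each, while the narrowest (s = 0) trains 5 candidates for the full 81 iterations.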
diff --git a/dev/Tasks.txt b/dev/Tasks-obsolete.txt
similarity index 99%
rename from dev/Tasks.txt
rename to dev/Tasks-obsolete.txt
index da1436b..561ef48 100644
--- a/dev/Tasks.txt
+++ b/dev/Tasks-obsolete.txt
@@ -365,3 +365,4 @@ SYSTEMDS-520 Lineage Tracing, Reuse and Integration III
SYSTEMDS-610 Cleaning Pipelines
* 611 Initial Brute force execution OK
+
diff --git a/docs/site/builtins-reference.md b/docs/site/builtins-reference.md
index c96e5d7..6931d42 100644
--- a/docs/site/builtins-reference.md
+++ b/docs/site/builtins-reference.md
@@ -32,6 +32,7 @@ limitations under the License.
* [`discoverFD`-Function](#discoverFD-function)
* [`glm`-Function](#glm-function)
* [`gridSearch`-Function](#gridSearch-function)
+ * [`hyperband`-Function](#hyperband-function)
* [`img_brightness`-Function](#img_brightness-function)
* [`img_crop`-Function](#img_crop-function)
* [`img_mirror`-Function](#img_mirror-function)
@@ -301,6 +302,55 @@ paramRanges = list(10^seq(0,-4), 10^seq(-5,-9), 10^seq(1,3))
[B, opt]= gridSearch(X=X, y=y, train="lm", predict="lmPredict", params=params, paramValues=paramRanges, verbose = TRUE)
```
+## `hyperband`-Function
+
+The `hyperband`-function is used for hyper parameter optimization and is based on multi-armed bandits and early elimination.
+Through multiple parallel brackets and consecutive trials, it returns the hyper parameter combination that performed best
+on a validation dataset. A set of hyper parameter combinations is drawn from uniform distributions with given ranges; these
+make up the candidates for `hyperband`.
+Notes:
+* `hyperband` is hard-coded for `lmCG`, and uses `lmpredict` for validation
+* `hyperband` is hard-coded to use the number of iterations as a resource
+* `hyperband` can only optimize continuous hyper parameters
+
+### Usage
+```r
+hyperband(X_train, y_train, X_val, y_val, params, paramRanges, R, eta, verbose)
+```
+
+### Arguments
+| Name | Type | Default | Description |
+| :------ | :------------- | -------- | :---------- |
+| X_train | Matrix[Double] | required | Input Matrix of training vectors. |
+| y_train | Matrix[Double] | required | Labels for training vectors. |
+| X_val | Matrix[Double] | required | Input Matrix of validation vectors. |
+| y_val | Matrix[Double] | required | Labels for validation vectors. |
+| params | List[String] | required | List of parameters to optimize. |
+| paramRanges | Matrix[Double] | required | The min and max values for the uniform distributions to draw from. One row per hyper parameter, first column specifies min, second column max value. |
+| R | Scalar[int] | 81 | Controls number of candidates evaluated. |
+| eta | Scalar[int] | 3 | Determines fraction of candidates to keep after each trial. |
+| verbose | Boolean | `TRUE` | If `TRUE`, print messages are activated. |
+
+### Returns
+| Type | Description |
+| :------------- | :---------- |
+| Matrix[Double] | 1-column matrix of weights of best performing candidate |
+| Frame[Unknown] | hyper parameters of best performing candidate |
+
+### Example
+```r
+X_train = rand(rows=50, cols=10);
+y_train = rowSums(X_train) + rand(rows=50, cols=1);
+X_val = rand(rows=50, cols=10);
+y_val = rowSums(X_val) + rand(rows=50, cols=1);
+
+params = list("reg");
+paramRanges = matrix("0 20", rows=1, cols=2);
+
+[bestWeights, optHyperParams] = hyperband(X_train=X_train, y_train=y_train,
+ X_val=X_val, y_val=y_val, params=params, paramRanges=paramRanges);
+```
+
## `img_brightness`-Function
The `img_brightness`-function is an image data augumentation function.
diff --git a/scripts/builtin/hyperband.dml b/scripts/builtin/hyperband.dml
new file mode 100644
index 0000000..3d5d326
--- /dev/null
+++ b/scripts/builtin/hyperband.dml
@@ -0,0 +1,162 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+m_hyperband = function(Matrix[Double] X_train, Matrix[Double] y_train,
+ Matrix[Double] X_val, Matrix[Double] y_val, List[String] params,
+ Matrix[Double] paramRanges, Scalar[int] R = 81, Scalar[int] eta = 3,
+ Boolean verbose = TRUE)
+ return (Matrix[Double] bestWeights, Frame[Unknown] bestHyperParams)
+{
+ # variable names follow publication where algorithm is introduced
+
+ numParams = length(params);
+
+ assert(numParams == nrow(paramRanges));
+ assert(ncol(paramRanges) == 2);
+ assert(nrow(X_train) == nrow(y_train));
+ assert(nrow(X_val) == nrow(y_val));
+ assert(ncol(X_train) == ncol(X_val));
+ assert(ncol(y_train) == ncol(y_val));
+
+ s_max = floor(log(R,eta));
+ B = (s_max + 1) * R;
+ bracketWinners = matrix(0, s_max+1, numParams+1);
+ winnerWeights = matrix(0, s_max+1, ncol(X_train));
+
+ parfor( s in s_max:0 ) {
+ debugMsgs = "--------------------------";
+
+ if( verbose ) {
+ debugMsgs = append(debugMsgs, "BRACKET s = " + s + "\n");
+ }
+
+ n = ceil(floor(B/R/(s+1)) * eta^s);
+ r = R * eta^(-s);
+
+ scoreboard = matrix(0,n,1+numParams);
+ candidateWeights = matrix(0,n,ncol(X_train));
+ # candidateWeights is not read until last round, as models are retrained
+ # from zero in every trial at the moment
+
+ # draw parameter values from uniform distribution
+ # draw e.g. regularisation factor for all the candidates at once
+ for( curParam in 1:numParams ) {
+ scoreboard[,curParam+1] =
+ rand(rows=n, cols=1, min=as.scalar(paramRanges[curParam, 1]),
+ max=as.scalar(paramRanges[curParam, 2]), pdf="uniform");
+ }
+
+ for( i in 0:s ) {
+ n_i = as.integer(floor(n * eta^(-i)));
+ r_i = as.integer(floor(r * eta^i));
+ # when using number of iterations as a resource, r_i has to be an
+ # integer; when using other types of resources, like portion of the
+ # dataset, this is not the case. This implementation hard-codes
+ # iterations as the resource. floor() for r_i is not included in
+ # the publication of hyperband
+
+ if( verbose ) {
+ debugMsgs = append(debugMsgs, "+++++++++++++++");
+ debugMsgs = append(debugMsgs, "i: " + i + " (current round)");
+ debugMsgs = append(debugMsgs, "n_i: " + n_i + " (number of configurations evaluated)");
+ debugMsgs = append(debugMsgs, "r_i: " + r_i + " (maximum number of iterations)\n");
+ }
+
+ parfor( curCandidate in 1:n_i ) {
+ # TODO argument list has to be passed from outside as well
+ # args is a residue from the implementation with eval("lmCG", args)
+ # init argument list
+ args = list(X=X_train, y=y_train, icpt=0, reg=1e-7,
+ tol=1e-7, maxi=r_i, verbose=TRUE);
+
+ for( curParam in 1:numParams ) {
+ # replace default values with values of the candidate at the
+ # corresponding location
+ args[as.scalar(params[curParam])] =
+ as.scalar(scoreboard[curCandidate,curParam+1]);
+ }
+ # original version
+ # weights = eval(learnAlgo, arguments);
+
+ # would be better to pass the whole list at once, this solution is error
+ # prone depending on the order of the list. hyper parameters to optimize
+ # are taken from args, as there they are reordered to be invariant to the
+ # order used at calling hyperband
+ weights = lmCG(X=X_train, y=y_train, tol=as.scalar(args[1]),
+ reg=as.scalar(args[2]), maxi=r_i, verbose=FALSE);
+
+ candidateWeights[curCandidate] = t(weights);
+ preds = lmpredict(X=X_val, w=weights);
+ scoreboard[curCandidate,1] = as.matrix(sum((y_val - preds)^2));
+ }
+
+ # reorder both matrices by same order
+ reorder = order(target=scoreboard, index.return=TRUE);
+ P = table(seq(1,n_i), reorder); # permutation matrix
+ scoreboard = P %*% scoreboard;
+ candidateWeights = P %*% candidateWeights;
+
+ if( verbose ) {
+ debugMsgs = append(debugMsgs, "validation loss | parameter values:");
+ debugMsgs = append(debugMsgs, toString(scoreboard));
+ }
+
+ numToKeep = floor(n_i/eta);
+
+ # in some cases, the list of remaining candidates would get emptied
+ if( numToKeep >= 1 ) {
+ scoreboard = scoreboard[1:numToKeep];
+ candidateWeights = candidateWeights[1:numToKeep];
+ }
+ }
+
+ if( verbose ) {
+ debugMsgs = append(debugMsgs, "Winner of Bracket: ");
+ debugMsgs = append(debugMsgs, toString(scoreboard[1]));
+ print(debugMsgs); # make print atomic because of parfor
+ }
+ bracketWinners[s+1] = scoreboard[1];
+ winnerWeights[s+1] = candidateWeights[1];
+ }
+
+ if( verbose ) {
+ print("--------------------------");
+ print("WINNERS OF EACH BRACKET (from s = 0 to s_max):");
+ print("validation loss | parameter values:");
+ print(toString(bracketWinners));
+ }
+
+ # reorder both matrices by same order
+ reorder2 = order(target=bracketWinners, index.return=TRUE);
+ P2 = table(seq(1,s_max+1), reorder2); # permutation matrix
+ bracketWinners = P2 %*% bracketWinners;
+ winnerWeights = P2 %*% winnerWeights;
+
+ bestHyperParams = as.frame(t(bracketWinners[1,2:(1+numParams)]));
+ bestWeights = t(winnerWeights[1]);
+
+ if( verbose ) {
+ print("Hyper parameters returned:");
+ print(toString(bestHyperParams));
+ print("Weights returned:");
+ print(toString(t(bestWeights)));
+ }
+}
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index d4267cc..53fc39e 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -96,6 +96,7 @@ public enum Builtins {
GLM("glm", true),
GNMF("gnmf", true),
GRID_SEARCH("gridSearch", true),
+ HYPERBAND("hyperband", true),
IFELSE("ifelse", false),
IMG_MIRROR("img_mirror", true),
IMG_BRIGHTNESS("img_brightness", true),
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinHyperbandTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinHyperbandTest.java
new file mode 100644
index 0000000..e32623a
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinHyperbandTest.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.lops.LopProperties.ExecType;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+
+public class BuiltinHyperbandTest extends AutomatedTestBase
+{
+ private final static String TEST_NAME = "HyperbandLM";
+ private final static String TEST_DIR = "functions/builtin/";
+ private final static String TEST_CLASS_DIR = TEST_DIR + BuiltinHyperbandTest.class.getSimpleName() + "/";
+
+ private final static int rows = 300;
+ private final static int cols = 20;
+
+ @Override
+ public void setUp() {
+ addTestConfiguration(TEST_NAME,new TestConfiguration(TEST_CLASS_DIR, TEST_NAME,new String[]{"R"}));
+ }
+
+ @Test
+ public void testHyperbandCP() {
+ runHyperband(ExecType.CP);
+ }
+
+ @Test
+ public void testHyperbandSpark() {
+ runHyperband(ExecType.SPARK);
+ }
+
+ private void runHyperband(ExecType et) {
+ ExecMode modeOld = setExecMode(et);
+ try {
+ loadTestConfiguration(getTestConfiguration(TEST_NAME));
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-args", input("X"), input("y"), output("R")};
+ double[][] X = getRandomMatrix(rows, cols, 0, 1, 0.8, 3);
+ double[][] y = getRandomMatrix(rows, 1, 0, 1, 0.8, 7);
+ writeInputMatrixWithMTD("X", X, true);
+ writeInputMatrixWithMTD("y", y, true);
+
+ runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+
+ //expected loss smaller than default invocation
+ Assert.assertTrue(TestUtils.readDMLBoolean(output("R")));
+ }
+ finally {
+ resetExecMode(modeOld);
+ }
+ }
+}
diff --git a/src/test/scripts/functions/builtin/HyperbandLM.dml b/src/test/scripts/functions/builtin/HyperbandLM.dml
new file mode 100644
index 0000000..34ced51
--- /dev/null
+++ b/src/test/scripts/functions/builtin/HyperbandLM.dml
@@ -0,0 +1,63 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+l2norm = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] B) return (Matrix[Double] loss) {
+ loss = as.matrix(sum((y - X%*%B)^2));
+}
+
+X = read($1);
+y = read($2);
+
+# size of dataset chosen such that number of maximum iterations influences the
+# performance of candidates
+numTrSamples = 100;
+numValSamples = 100;
+
+X_train = X[1:numTrSamples,];
+y_train = y[1:numTrSamples,];
+X_val = X[(numTrSamples+1):(numTrSamples+numValSamples+1),];
+y_val = y[(numTrSamples+1):(numTrSamples+numValSamples+1),];
+X_test = X[(numTrSamples+numValSamples+2):nrow(X),];
+y_test = y[(numTrSamples+numValSamples+2):nrow(X),];
+
+params = list("reg", "tol");
+
+# only works with continuous hyper parameters in this implementation
+paramRanges = matrix(0, rows=2, cols=2);
+
+paramRanges[1,1] = 0;
+paramRanges[1,2] = 20;
+paramRanges[2,1] = 10^-12;
+paramRanges[2,2] = 10^-10;
+
+# use lmCG, because this implementation of hyperband only makes sense with
+# iterative algorithms
+[B1, optHyperParams] = hyperband(X_train=X_train, y_train=y_train, X_val=X_val,
+ y_val=y_val, params=params, paramRanges=paramRanges, R=50, eta=3, verbose=TRUE);
+
+# train reference with default values
+B2 = lmCG(X=X_train, y=y_train, verbose=FALSE);
+
+l1 = l2norm(X_test, y_test, B1);
+l2 = l2norm(X_test, y_test, B2);
+R = as.scalar(l1 <= l2);
+
+write(R, $3)