You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/12/31 19:16:43 UTC
[systemds] branch main updated: [SYSTEMDS-3265] Fix gridSearch for multi-class classification
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new d573a1f [SYSTEMDS-3265] Fix gridSearch for multi-class classification
d573a1f is described below
commit d573a1f15e94da053b9b2153eee3b15c4c25fcac
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Fri Dec 31 20:11:54 2021 +0100
[SYSTEMDS-3265] Fix gridSearch for multi-class classification
This patch generalizes the existing grid search second-order builtin
function to properly handle multi-column models as used during
multi-class classification. Models are now temporarily reshaped to
vectors (which is a no-op for dense models) and returned in
linearized form. The caller can then reshape the returned vector back
into a matrix, given the number of classes, and use it.
---
scripts/builtin/gridSearch.dml | 13 ++++++-----
.../builtin/part1/BuiltinGridSearchTest.java | 26 +++++++++++++++++-----
.../functions/builtin/GridSearchMLogreg.dml | 4 +++-
3 files changed, 31 insertions(+), 12 deletions(-)
diff --git a/scripts/builtin/gridSearch.dml b/scripts/builtin/gridSearch.dml
index eab7756..a8c0986 100644
--- a/scripts/builtin/gridSearch.dml
+++ b/scripts/builtin/gridSearch.dml
@@ -30,8 +30,8 @@
# y Matrix[Double] --- Input Matrix of vectors.
# train String --- Name ft of the train function to call via ft(trainArgs)
# predict String --- Name fp of the loss function to call via fp((predictArgs,B))
-# numB Integer --- Maximum number of parameters in model B (pass the maximum because the
-# size of B may vary with parameters like icpt
+# numB Integer --- Maximum number of parameters in model B (pass the max because the size
+# may vary with parameters like icpt or multi-class classification)
# params List[String] --- List of varied hyper-parameter names
# paramValues List[Unknown] --- List of matrices providing the parameter values as
# columnvectors for position-aligned hyper-parameters in 'params'
@@ -52,6 +52,7 @@
# NAME TYPE MEANING
# ----------------------------------------------------------------------------------------------------------------------
# B Matrix[Double] Matrix[Double]the trained model with minimal loss (by the 'predict' function)
+# Multi-column models are returned as a column-major linearized column vector
# opt Matrix[Double] one-row frame w/ optimal hyperparameters (by 'params' position)
#-----------------------------------------------------------------------------------------------------------------------
@@ -127,10 +128,10 @@ m_gridSearch = function(Matrix[Double] X, Matrix[Double] y, String train, String
ltrainArgs['X'] = rbind(tmpX);
ltrainArgs['y'] = rbind(tmpy);
lbeta = t(eval(train, ltrainArgs));
- cvbeta[,1:ncol(lbeta)] = cvbeta[,1:ncol(lbeta)] + lbeta;
+ cvbeta[,1:length(lbeta)] = cvbeta[,1:length(lbeta)] + matrix(lbeta, 1, length(lbeta));
lpredictArgs[1] = as.matrix(testX);
lpredictArgs[2] = as.matrix(testy);
- cvloss += eval(predict, append(lpredictArgs,t(lbeta)));
+ cvloss += eval(predict, append(lpredictArgs, t(lbeta)));
}
Rbeta[i,] = cvbeta / cvk; # model averaging
Rloss[i,] = cvloss / cvk;
@@ -145,8 +146,8 @@ m_gridSearch = function(Matrix[Double] X, Matrix[Double] y, String train, String
ltrainArgs[as.scalar(params[j])] = as.scalar(HP[i,j]);
# b) core training/scoring and write-back
lbeta = t(eval(train, ltrainArgs))
- Rbeta[i,1:ncol(lbeta)] = lbeta;
- Rloss[i,] = eval(predict, append(predictArgs,t(lbeta)));
+ Rbeta[i,1:length(lbeta)] = matrix(lbeta, 1, length(lbeta));
+ Rloss[i,] = eval(predict, append(predictArgs, t(lbeta)));
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGridSearchTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGridSearchTest.java
index a8d1310..6cc6411 100644
--- a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGridSearchTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGridSearchTest.java
@@ -38,8 +38,8 @@ public class BuiltinGridSearchTest extends AutomatedTestBase
private final static String TEST_DIR = "functions/builtin/";
private final static String TEST_CLASS_DIR = TEST_DIR + BuiltinGridSearchTest.class.getSimpleName() + "/";
- private final static int rows = 400;
- private final static int cols = 20;
+ private final static int _rows = 400;
+ private final static int _cols = 20;
private boolean _codegen = false;
@Override
@@ -106,7 +106,22 @@ public class BuiltinGridSearchTest extends AutomatedTestBase
runGridSearch(TEST_NAME4, ExecMode.HYBRID, false);
}
- private void runGridSearch(String testname, ExecMode et, boolean codegen)
+ @Test
+ public void testGridSearchMLogreg4CP() {
+ runGridSearch(TEST_NAME2, ExecMode.SINGLE_NODE, 10, 4, false);
+ }
+
+ @Test
+ public void testGridSearchMLogreg4Hybrid() {
+ runGridSearch(TEST_NAME2, ExecMode.HYBRID, 10, 4, false);
+ }
+
+
+ private void runGridSearch(String testname, ExecMode et, boolean codegen) {
+ runGridSearch(testname, et, _cols, 2, codegen); //binary classification
+ }
+
+ private void runGridSearch(String testname, ExecMode et, int cols, int nc, boolean codegen)
{
ExecMode modeOld = setExecMode(et);
_codegen = codegen;
@@ -117,8 +132,9 @@ public class BuiltinGridSearchTest extends AutomatedTestBase
fullDMLScriptName = HOME + testname + ".dml";
programArgs = new String[] {"-args", input("X"), input("y"), output("R")};
- double[][] X = getRandomMatrix(rows, cols, 0, 1, 0.8, 7);
- double[][] y = getRandomMatrix(rows, 1, 1, 2, 1, 1);
+ double max = testname.equals(TEST_NAME2) ? nc : 2;
+ double[][] X = getRandomMatrix(_rows, cols, 0, 1, 0.8, 7);
+ double[][] y = getRandomMatrix(_rows, 1, 1, max, 1, 1);
writeInputMatrixWithMTD("X", X, true);
writeInputMatrixWithMTD("y", y, true);
diff --git a/src/test/scripts/functions/builtin/GridSearchMLogreg.dml b/src/test/scripts/functions/builtin/GridSearchMLogreg.dml
index ec2bf9d..ce54d5d 100644
--- a/src/test/scripts/functions/builtin/GridSearchMLogreg.dml
+++ b/src/test/scripts/functions/builtin/GridSearchMLogreg.dml
@@ -26,6 +26,7 @@ accuracy = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] B) return
X = read($1);
y = round(read($2));
+nc = max(y);
N = 200;
Xtrain = X[1:N,];
@@ -36,10 +37,11 @@ ytest = y[(N+1):nrow(X),];
params = list("icpt", "reg", "maxii");
paramRanges = list(seq(0,2),10^seq(1,-6), 10^seq(1,3));
trainArgs = list(X=Xtrain, Y=ytrain, icpt=-1, reg=-1, tol=1e-9, maxi=100, maxii=-1, verbose=FALSE);
-[B1,opt] = gridSearch(X=Xtrain, y=ytrain, train="multiLogReg", predict="accuracy", numB=ncol(X)+1,
+[B1,opt] = gridSearch(X=Xtrain, y=ytrain, train="multiLogReg", predict="accuracy", numB=(ncol(X)+1)*(nc-1),
params=params, paramValues=paramRanges, trainArgs=trainArgs, verbose=TRUE);
B2 = multiLogReg(X=Xtrain, Y=ytrain, verbose=TRUE);
+B1 = matrix(B1, nrow(B1)/(nc-1), (nc-1), FALSE)
l1 = accuracy(Xtest, ytest, B1);
l2 = accuracy(Xtest, ytest, B2);
R = as.scalar(l1 < l2);