You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2020/11/10 12:12:05 UTC
[systemds] branch master updated: [SYSTEMDS-2722] New built-in
function for train/test splitting
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new e8153ca [SYSTEMDS-2722] New built-in function for train/test splitting
e8153ca is described below
commit e8153caf8107374548178344261ed21c9f77275a
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Tue Nov 10 13:11:33 2020 +0100
[SYSTEMDS-2722] New built-in function for train/test splitting
This patch introduces a new dml-bodied builtin function for common
train/test splitting of feature matrices and labels. We support two
types: contiguous (ranges of rows) and sampled (uniform selection of
rows without replacement).
---
scripts/builtin/split.dml | 63 ++++++++++++++++++++++
.../java/org/apache/sysds/common/Builtins.java | 1 +
2 files changed, 64 insertions(+)
diff --git a/scripts/builtin/split.dml b/scripts/builtin/split.dml
new file mode 100644
index 0000000..b2d78d0
--- /dev/null
+++ b/scripts/builtin/split.dml
@@ -0,0 +1,63 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Split input data X and y into contiguous or sampled train/test sets
+# ------------------------------------------------------------------------------
+# NAME TYPE DEFAULT MEANING
+# ------------------------------------------------------------------------------
+# X Matrix --- Input feature matrix
+# y Matrix --- Input label matrix
+# f Double 0.7 Train set fraction, expected in the open interval (0,1)
+# cont Boolean TRUE contiguous splits, otherwise sampled
+# ------------------------------------------------------------------------------
+# Xtrain Matrix --- Train split of feature matrix
+# Xtest Matrix --- Test split of feature matrix
+# ytrain Matrix --- Train split of label matrix
+# ytest Matrix --- Test split of label matrix
+# ------------------------------------------------------------------------------
+
+m_split = function(Matrix[Double] X, Matrix[Double] y, Double f=0.7, Boolean cont=TRUE)
+ return (Matrix[Double] Xtrain, Matrix[Double] Xtest, Matrix[Double] ytrain, Matrix[Double] ytest)
+{
+ # basic sanity checks: violations are only reported via print, execution continues afterwards
+ if( f <= 0 | f >= 1 )
+ print("Invalid train/test split configuration: f="+f);
+ if( nrow(X) != nrow(y) )
+ print("Mismatching number of rows X and y: "+nrow(X)+" "+nrow(y) )
+
+ # contiguous train/test splits: rows 1 to f*nrow(X) form the train set, all remaining rows the test set
+ if( cont ) {
+ Xtrain = X[1:f*nrow(X),];
+ ytrain = y[1:f*nrow(X),];
+ Xtest = X[(nrow(Xtrain)+1):nrow(X),]; # test rows start right after the last train row
+ ytest = y[(nrow(Xtrain)+1):nrow(X),]; # same row range as Xtest, derived from nrow(Xtrain)
+ }
+ # sampled train/test splits: a row goes to the train set if its rand() draw is <= f, otherwise to the test set
+ else {
+ I = rand(rows=nrow(X), cols=1) <= f; # per-row indicator (presumably rand() defaults to uniform [0,1] -- confirm)
+ P1 = removeEmpty(target=diag(I), margin="rows", select=I); # selection matrix for rows flagged in I (train)
+ P2 = removeEmpty(target=diag(I==0), margin="rows", select=I==0); # selection matrix for the complementary rows (test)
+ Xtrain = P1 %*% X;
+ ytrain = P1 %*% y;
+ Xtest = P2 %*% X;
+ ytest = P2 %*% y;
+ }
+}
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index 2c08ef4..273a520 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -187,6 +187,7 @@ public enum Builtins {
SLICEFINDER("slicefinder", true),
SMOTE("smote", true),
SOLVE("solve", false),
+ SPLIT("split", true),
SQRT("sqrt", false),
SUM("sum", false),
SVD("svd", false, ReturnType.MULTI_RETURN),