You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2020/11/10 12:12:05 UTC

[systemds] branch master updated: [SYSTEMDS-2722] New built-in function for train/test splitting

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new e8153ca  [SYSTEMDS-2722] New built-in function for train/test splitting
e8153ca is described below

commit e8153caf8107374548178344261ed21c9f77275a
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Tue Nov 10 13:11:33 2020 +0100

    [SYSTEMDS-2722] New built-in function for train/test splitting
    
    This patch introduces a new dml-bodied builtin function for common
    train/test splitting of feature matrices and labels. We support two
    types: contiguous (ranges of rows) and sampled (uniform selection of
    rows without replacement).
---
 scripts/builtin/split.dml                          | 63 ++++++++++++++++++++++
 .../java/org/apache/sysds/common/Builtins.java     |  1 +
 2 files changed, 64 insertions(+)

diff --git a/scripts/builtin/split.dml b/scripts/builtin/split.dml
new file mode 100644
index 0000000..b2d78d0
--- /dev/null
+++ b/scripts/builtin/split.dml
@@ -0,0 +1,63 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Split input data X and y into contiguous or sampled train/test sets
+# ------------------------------------------------------------------------------
+# NAME   TYPE    DEFAULT  MEANING
+# ------------------------------------------------------------------------------
+# X      Matrix  ---      Input feature matrix
+# y      Matrix  ---      Input label matrix (same number of rows as X)
+# f      Double  0.7      Train set fraction, must lie in the open interval (0,1)
+# cont   Boolean TRUE     Contiguous splits (ranges of rows), otherwise sampled
+# seed   Integer -1       Seed for sampled splits (-1 selects a random seed)
+# ------------------------------------------------------------------------------
+# Xtrain Matrix  ---      Train split of feature matrix
+# Xtest  Matrix  ---      Test split of feature matrix
+# ytrain Matrix  ---      Train split of label matrix
+# ytest  Matrix  ---      Test split of label matrix
+# ------------------------------------------------------------------------------
+
+m_split = function(Matrix[Double] X, Matrix[Double] y, Double f=0.7, Boolean cont=TRUE, Integer seed=-1)
+  return (Matrix[Double] Xtrain, Matrix[Double] Xtest, Matrix[Double] ytrain, Matrix[Double] ytest) 
+{
+  # basic sanity checks: abort (rather than just print) on invalid inputs
+  if( f <= 0 | f >= 1 )
+    stop("Invalid train/test split configuration: f="+f);
+  if( nrow(X) != nrow(y) )
+    stop("Mismatching number of rows X and y: "+nrow(X)+" "+nrow(y) )
+  # contiguous train/test splits: first f*nrow(X) rows train, remaining rows test
+  if( cont ) {
+    Xtrain = X[1:f*nrow(X),];
+    ytrain = y[1:f*nrow(X),];
+    Xtest = X[(nrow(Xtrain)+1):nrow(X),];
+    ytest = y[(nrow(Xtrain)+1):nrow(X),];
+  }
+  # sampled train/test splits (uniform row selection without replacement)
+  else {
+    I = rand(rows=nrow(X), cols=1, seed=seed) <= f;
+    P1 = removeEmpty(target=diag(I), margin="rows", select=I);
+    P2 = removeEmpty(target=diag(I==0), margin="rows", select=I==0);
+    Xtrain = P1 %*% X;
+    ytrain = P1 %*% y;
+    Xtest = P2 %*% X;
+    ytest = P2 %*% y;
+  }
+}
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index 2c08ef4..273a520 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -187,6 +187,7 @@ public enum Builtins {
 	SLICEFINDER("slicefinder", true),
 	SMOTE("smote", true),
 	SOLVE("solve", false),
+	SPLIT("split", true),
 	SQRT("sqrt", false),
 	SUM("sum", false),
 	SVD("svd", false, ReturnType.MULTI_RETURN),