You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ba...@apache.org on 2022/01/20 11:07:22 UTC
[systemds] branch main updated: [SYSTEMDS-3149] Decision Tree Prediction Builtin DIA project WS2021/22 Closes #1506
This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 741be73 [SYSTEMDS-3149] Decision Tree Prediction Builtin DIA project WS2021/22 Closes #1506
741be73 is described below
commit 741be739c8659e67105a6ba66a972b1b3f7d3d11
Author: Magdalena Hinterkoerner <m....@student.tugraz.at>
AuthorDate: Wed Jan 5 14:26:12 2022 +0100
[SYSTEMDS-3149] Decision Tree Prediction Builtin
DIA project WS2021/22
Closes #1506
---
scripts/builtin/decisionTreePredict.dml | 149 +++++++++++++++++++++
.../java/org/apache/sysds/common/Builtins.java | 1 +
.../part1/BuiltinDecisionTreePredictTest.java | 87 ++++++++++++
.../functions/builtin/decisionTreePredict.dml | 25 ++++
4 files changed, 262 insertions(+)
diff --git a/scripts/builtin/decisionTreePredict.dml b/scripts/builtin/decisionTreePredict.dml
new file mode 100644
index 0000000..48c7f6f
--- /dev/null
+++ b/scripts/builtin/decisionTreePredict.dml
@@ -0,0 +1,149 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+#
+# Builtin script implementing prediction based on classification trees with scale features using prediction methods of the
+# Hummingbird paper (https://www.usenix.org/system/files/osdi20-nakandala.pdf).
+#
+# INPUT PARAMETERS:
+# ---------------------------------------------------------------------------------------------
+# NAME TYPE MEANING
+# ---------------------------------------------------------------------------------------------
+# M Matrix[Double] Decision tree matrix M, as generated by scripts/builtin/decisionTree.dml, where each column corresponds
+# to a node in the learned tree and each row contains the following information:
+# M[1,j]: id of node j (in a complete binary tree)
+# M[2,j]: Offset (no. of columns) to left child of j if j is an internal node, otherwise 0
+# M[3,j]: Feature index of the feature (scale feature id if the feature is scale or
+# categorical feature id if the feature is categorical)
+# that node j looks at if j is an internal node, otherwise 0
+# M[4,j]: If j is an internal node: type of the feature that node j looks at; holds
+# the same information as R input vector
+# If j is a leaf node: the label predicted at j (this is the value that is written to Y)
+# M[5,j]: If j is an internal node: 1 if the feature chosen for j is scale,
+# otherwise the size of the subset of values
+# stored in rows 6,7,... if j is categorical
+# If j is a leaf node: number of misclassified samples reaching node j
+# M[6:,j]: If j is an internal node: Threshold the example's feature value is compared
+# to is stored at M[6,j] if the feature chosen for j is scale,
+# otherwise if the feature chosen for j is categorical rows 6,7,... depict the value subset chosen for j
+# If j is a leaf node: 1 if j is impure and the number of samples at j > threshold, otherwise 0
+#
+# X Matrix[Double] Feature matrix X
+#
+# strategy String Prediction strategy. The Hummingbird paper describes "GEMM" (generic matrix multiplication),
+# "TT" (tree traversal), and "PTT" (perfect tree traversal); only "TT" is currently implemented,
+# any other value prints a message and yields an empty (0 x 0) result
+# -------------------------------------------------------------------------------------------
+# OUTPUT:
+# ---------------------------------------------------------------------------------------------
+# NAME TYPE MEANING
+# ---------------------------------------------------------------------------------------------
+# Y Matrix[Double] Row vector (1 x nrow(X)) containing the predicted label for each row of X
+# ---------------------------------------------------------------------------------------------
+
+m_decisionTreePredict = function(Matrix[Double] M, Matrix[Double] X, String strategy)
+  return (Matrix[Double] Y)
+{
+  # Entry point of the builtin: dispatches to the implementation of the requested
+  # prediction strategy. Only "TT" (tree traversal) is implemented here.
+  if (strategy == "TT") {
+    Y = predict_TT(M, X)
+  }
+  else {
+    # Unknown or not yet implemented strategy (e.g. "GEMM", "PTT"): report it and
+    # return an empty matrix so that the caller can detect the failure.
+    print ("No such strategy: " + strategy)
+    Y = matrix("0", rows=0, cols=0)
+  }
+}
+
+# Tree-traversal ("TT") prediction.
+#
+# For every row of X the tree is walked starting at column 1 of M: at each visited
+# node the sample's feature N_F[node] is compared with the threshold N_T[node];
+# a smaller value follows N_L[node], otherwise N_R[node]. createNodeTensors maps
+# every leaf to itself, so the walk has arrived as soon as the node no longer changes.
+#
+# M  tree matrix, one column per node (see the header of this file)
+# X  feature matrix, one sample per row
+# Y  1 x nrow(X) row vector, Y[1,k] = M[4, node reached by sample k]
+predict_TT = function (Matrix[Double] M, Matrix[Double] X)
+  return (Matrix[Double] Y)
+{
+  Y = matrix(0, rows=1, cols=nrow(X))
+  n = ncol(M)
+  [N_L, N_R, N_F, N_T] = createNodeTensors(M)
+
+  parfor (k in 1:nrow(X)){
+    # iterate over every sample in X matrix
+    sample = X[k,]
+    current_node = 1
+    steps = 0
+    descending = TRUE
+
+    # Walk until a node maps to itself instead of for a fixed number of levels:
+    # deriving the depth from the column count (ceiling(log(n+1,2))) is only correct
+    # for a perfect tree and stops too early on unbalanced ones. The step limit n
+    # guarantees termination even for a malformed M.
+    while (descending & steps < n){
+      feature_id = as.scalar(N_F[1, current_node])
+      feature = as.scalar(sample[,feature_id]) # select feature from sample data
+      threshold = as.scalar(N_T[1, current_node])
+
+      if (feature < threshold){
+        # move on to left child node
+        next_node = as.scalar(N_L[1, current_node])
+      } else {
+        # move on to right child node
+        next_node = as.scalar(N_R[1, current_node])
+      }
+      descending = (next_node != current_node)
+      current_node = next_node
+      steps += 1
+    }
+
+    # row 4 of M is read as the label of the node the sample ended up in
+    Y[1, k] = M[4, current_node]
+  }
+}
+
+# Derives the four 1 x ncol(M) lookup vectors used by the tree-traversal prediction:
+#   N_L  node to visit next when the comparison goes left
+#   N_R  node to visit next when the comparison goes right
+#   N_F  feature index evaluated at the node (1 where M[3,] is 0)
+#   N_T  threshold of the node (copy of M[6,])
+# Nodes whose left-child offset M[2,] is 0 point to themselves in N_L and N_R.
+#
+# NOTE(review): the values taken from M[1,] (node ids) are used below as column
+# indices into M[2,] and are stored in N_L/N_R - this presumably only works when
+# node id and column position coincide; confirm against the producer of M.
+createNodeTensors = function( Matrix[Double] M )
+  return ( Matrix[Double] N_L, Matrix[Double] N_R, Matrix[Double] N_F, Matrix[Double] N_T)
+{
+  ids = M[1,]      # all tree nodes
+  offsets = M[2,]  # list of node offsets to their left children
+  n_nodes = ncol(ids)
+
+  N_L = matrix(0, rows=1, cols=n_nodes)
+  N_R = matrix(0, rows=1, cols=n_nodes)
+  N_F = matrix(0, rows=1, cols=n_nodes)
+  N_T = matrix(0, rows=1, cols=n_nodes)
+
+  parfor (i in 1:n_nodes){
+    off = as.scalar(offsets[1, i])
+
+    if (off == 0){
+      # no left-child offset: the node points to itself in both directions
+      N_L[1, i] = as.matrix(i)
+      N_R[1, i] = as.matrix(i)
+    } else {
+      # internal node: the left child sits 'off' columns to the right
+      left_pos = i + off
+      left_id = as.scalar(ids[1, left_pos])
+      N_L[1, i] = ids[1, left_pos]
+
+      right_id = left_id + 1
+      # NOTE(review): when the left child has offset 0 and its successor does not,
+      # the node becomes its own right child - intent unclear, confirm.
+      if (as.scalar(offsets[1, left_id]) == 0 & as.scalar(offsets[1, right_id]) != 0){
+        right_id = i
+      }
+      N_R[1, i] = ids[1, right_id]
+    }
+
+    # feature index of the node; nodes without one fall back to feature 1
+    if (as.scalar(M[3,i]) != 0){
+      N_F[1, i] = M[3,i]
+    } else {
+      N_F[1, i] = as.matrix(1)
+    }
+
+    # threshold of the node; a zero entry stays zero, so M[6,i] is copied as is
+    N_T[1, i] = M[6,i]
+  }
+}
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index a124220..85ca3c7 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -107,6 +107,7 @@ public enum Builtins {
DBSCAN("dbscan", true),
DBSCANAPPLY("dbscanApply", true),
DECISIONTREE("decisionTree", true),
+ DECISIONTREEPREDICT("decisionTreePredict", true),
DECOMPRESS("decompress", false),
DEEPWALK("deepWalk", true),
DETECTSCHEMA("detectSchema", false),
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreePredictTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreePredictTest.java
new file mode 100644
index 0000000..04c1a53
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreePredictTest.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin.part1;
+
+import java.util.HashMap;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+/**
+ * Runs the decisionTreePredict builtin through the test script decisionTreePredict.dml on a
+ * hand-written tree with seven nodes and compares the predicted labels of five samples with
+ * the expected ones.
+ */
+public class BuiltinDecisionTreePredictTest extends AutomatedTestBase {
+	private final static String TEST_NAME = "decisionTreePredict";
+	private final static String TEST_DIR = "functions/builtin/";
+	// Named after this class (not BuiltinDecisionTreeTest) so that the working directory
+	// of this test is not shared with the one of the decision tree training test.
+	private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinDecisionTreePredictTest.class.getSimpleName() + "/";
+
+	private final static double eps = 1e-10;
+
+	@Override
+	public void setUp() {
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"C"}));
+	}
+
+	@Test
+	public void testDecisionTreePredictDefaultCP() {
+		runDecisionTreePredict(ExecType.CP, "TT");
+	}
+
+	@Test
+	public void testDecisionTreePredictSP() {
+		runDecisionTreePredict(ExecType.SPARK, "TT");
+	}
+
+	/**
+	 * Writes the tree matrix M and the feature matrix X, executes the test script with the
+	 * given strategy and checks the resulting row vector Y.
+	 *
+	 * @param instType execution type the script is run with
+	 * @param strategy prediction strategy that is passed on to the builtin
+	 */
+	private void runDecisionTreePredict(ExecType instType, String strategy) {
+		Types.ExecMode platformOld = setExecMode(instType);
+		try {
+			loadTestConfiguration(getTestConfiguration(TEST_NAME));
+
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			programArgs = new String[] {"-args", input("M"), input("X"), strategy, output("Y")};
+
+			// five samples with three features each
+			double[][] X = {{0.5, 7, 0.1}, {0.5, 7, 0.7}, {-1, -0.2, 3}, {-1, -0.2, -0.8}, {-0.3, -0.7, 3}};
+			// one column per node: rows are node id, offset to the left child, feature index,
+			// feature type / leaf label, row 5 of the tree format, threshold
+			double[][] M = {{1, 2, 3, 4, 5, 6, 7}, {1, 2, 3, 0, 0, 0, 0}, {1, 2, 3, 0, 0, 0, 0},
+				{1, 1, 1, 4, 5, 6, 7}, {1, 1, 1, 0, 0, 0, 0}, {0, -0.5, 0.5, 0, 0, 0, 0}};
+
+			// Y is a row vector: cell (1, k) holds the label of the k-th sample
+			HashMap<MatrixValue.CellIndex, Double> expected_Y = new HashMap<>();
+			expected_Y.put(new MatrixValue.CellIndex(1, 1), 6.0);
+			expected_Y.put(new MatrixValue.CellIndex(1, 2), 7.0);
+			expected_Y.put(new MatrixValue.CellIndex(1, 3), 5.0);
+			expected_Y.put(new MatrixValue.CellIndex(1, 4), 5.0);
+			expected_Y.put(new MatrixValue.CellIndex(1, 5), 4.0);
+
+			writeInputMatrixWithMTD("M", M, true);
+			writeInputMatrixWithMTD("X", X, true);
+
+			runTest(true, false, null, -1);
+
+			HashMap<MatrixValue.CellIndex, Double> actual_Y = readDMLMatrixFromOutputDir("Y");
+
+			TestUtils.compareMatrices(expected_Y, actual_Y, eps, "Expected-DML", "Actual-DML");
+		}
+		finally {
+			// restore the execution mode that was active before this test
+			rtplatform = platformOld;
+		}
+	}
+}
diff --git a/src/test/scripts/functions/builtin/decisionTreePredict.dml b/src/test/scripts/functions/builtin/decisionTreePredict.dml
new file mode 100644
index 0000000..208a827
--- /dev/null
+++ b/src/test/scripts/functions/builtin/decisionTreePredict.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Test driver for the decisionTreePredict builtin.
+# Arguments: $1 tree matrix M, $2 feature matrix X, $3 prediction strategy, $4 output file for the predictions
+tree = read($1);
+features = read($2);
+predictions = decisionTreePredict(M = tree, X = features, strategy = $3);
+write(predictions, $4);