Posted to commits@systemds.apache.org by ar...@apache.org on 2021/05/01 19:19:42 UTC
[systemds] branch master updated: [SYSTEMDS-2959] Add a predict function for Naive Bayes
This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 624a612 [SYSTEMDS-2959] Add a predict function for Naive Bayes
624a612 is described below
commit 624a61229f9f744958faa310a9e80a86518b3c44
Author: sfathollahzadeh <s....@gmail.com>
AuthorDate: Sat May 1 21:12:17 2021 +0200
[SYSTEMDS-2959] Add a predict function for Naive Bayes
Closes #1217
---
docs/site/builtins-reference.md | 102 ++++++++++++++-------
scripts/builtin/{naivebayes.dml => naiveBayes.dml} | 2 +-
.../builtin/naiveBayesPredict.dml | 16 +++-
.../java/org/apache/sysds/common/Builtins.java | 3 +-
.../builtin/BuiltinNaiveBayesPredictTest.java | 102 +++++++++++++++++++++
src/test/scripts/functions/builtin/NaiveBayes.dml | 2 +-
.../scripts/functions/builtin/NaiveBayesPredict.R | 66 +++++++++++++
.../{NaiveBayes.dml => NaiveBayesPredict.dml} | 24 ++++-
src/test/scripts/installDependencies.R | 1 +
9 files changed, 270 insertions(+), 48 deletions(-)
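
Note on the scoring added in scripts/builtin/naiveBayesPredict.dml: the function appends a column of ones to X, multiplies by the transposed log of cbind(C, P), and takes the row-wise index of the maximum. The plain-Java sketch below spells out that arithmetic for small arrays. The class and method names are hypothetical and not part of SystemDS. It assumes all priors and conditionals are strictly positive (as they are with Laplace smoothing), and it resolves ties to the last maximum, mirroring `ties.method="last"` in the R reference script of this commit.

```java
import java.util.Arrays;

// Hypothetical illustration only; not part of the SystemDS code base.
final class NaiveBayesScoreSketch {

    // yRaw[i][k] = log(prior[k]) + sum_j x[i][j] * log(conditionals[k][j]),
    // which is what [X, 1] %*% t(log(cbind(C, P))) computes entry by entry.
    static double[][] rawScores(double[][] x, double[] prior, double[][] conditionals) {
        int n = x.length;
        int k = prior.length;
        double[][] yRaw = new double[n][k];
        for (int i = 0; i < n; i++) {
            for (int c = 0; c < k; c++) {
                double s = Math.log(prior[c]);
                for (int j = 0; j < x[i].length; j++) {
                    s += x[i][j] * Math.log(conditionals[c][j]);
                }
                yRaw[i][c] = s;
            }
        }
        return yRaw;
    }

    // 1-based index of the row maximum; ties go to the last maximum (assumption).
    static int[] predict(double[][] yRaw) {
        int[] y = new int[yRaw.length];
        for (int i = 0; i < yRaw.length; i++) {
            int best = 0;
            for (int c = 1; c < yRaw[i].length; c++) {
                if (yRaw[i][c] >= yRaw[i][best]) {
                    best = c;
                }
            }
            y[i] = best + 1;
        }
        return y;
    }

    public static void main(String[] args) {
        double[][] x = {{2, 0}, {0, 3}};             // two test rows, two features
        double[] prior = {0.5, 0.5};                  // one prior per class
        double[][] cond = {{0.8, 0.2}, {0.2, 0.8}};   // classes x features
        System.out.println(Arrays.toString(predict(rawScores(x, prior, cond)))); // [1, 2]
    }
}
```

Working in log space turns the product of per-feature probabilities into a sum, which is why a single matrix multiplication is enough in the DML version.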
diff --git a/docs/site/builtins-reference.md b/docs/site/builtins-reference.md
index c3d21ca..c46de37 100644
--- a/docs/site/builtins-reference.md
+++ b/docs/site/builtins-reference.md
@@ -61,13 +61,14 @@ limitations under the License.
* [`gnmf`-Function](#gnmf-function)
* [`mdedup`-Function](#mdedup-function)
* [`msvm`-Function](#msvm-function)
- * [`naivebayes`-Function](#naivebayes-function)
+ * [`naiveBayes`-Function](#naiveBayes-function)
+ * [`naiveBayesPredict`-Function](#naiveBayesPredict-function)
* [`outlier`-Function](#outlier-function)
* [`toOneHot`-Function](#toOneHOt-function)
* [`winsorize`-Function](#winsorize-function)
* [`gmm`-Function](#gmm-function)
* [`correctTypos`-Function](#correcttypos-function)
-
+
# Introduction
The DML (Declarative Machine Learning) language has built-in functions which enable access to both low- and high-level functions
@@ -106,7 +107,7 @@ Note that this function is highly **unstable** and will be overworked and might
##### `data`-Argument
-The `data`-argument can be a `Matrix` of any datatype from which the elements will be taken and placed in the tensor
+The `data`-argument can be a `Matrix` of any datatype from which the elements will be taken and placed in the tensor
until filled. If given as a `Tensor` the same procedure takes place. We iterate through `Matrix` and `Tensor` by starting
with each dimension index at `0` and then incrementing the lowest one, until we made a complete pass over the dimension,
and then increasing the dimension index above. This will be done until the `Tensor` is completely filled.
@@ -167,14 +168,14 @@ confusionMatrix(P, Y)
| Y | Matrix[Double] | --- | vector of Golden standard One Hot Encoded |
### Returns
-
+
| Name | Type | Description |
| :----------- | :------------- | :---------- |
| ConfusionSum | Matrix[Double] | The Confusion Matrix Sums of classifications |
| ConfusionAvg | Matrix[Double] | The Confusion Matrix averages of each true class |
### Example
-
+
```r
numClasses = 1
z = rand(rows = 5, cols = 1, min = 1, max = 9)
@@ -385,7 +386,7 @@ beta = glm(X=X,Y=y)
## `gridSearch`-Function
The `gridSearch`-function is used to find the optimal hyper-parameters of a model which results in the most _accurate_
-predictions. This function takes `train` and `eval` functions by name.
+predictions. This function takes `train` and `eval` functions by name.
### Usage
```r
@@ -423,13 +424,13 @@ paramRanges = list(10^seq(0,-4), 10^seq(-5,-9), 10^seq(1,3))
## `hyperband`-Function
-The `hyperband`-function is used for hyper parameter optimization and is based on multi-armed bandits and early elimination.
+The `hyperband`-function is used for hyper parameter optimization and is based on multi-armed bandits and early elimination.
Through multiple parallel brackets and consecutive trials it will return the hyper parameter combination which performed best
on a validation dataset. A set of hyper parameter combinations is drawn from uniform distributions with given ranges; Those
make up the candidates for `hyperband`.
-Notes:
+Notes:
* `hyperband` is hard-coded for `lmCG`, and uses `lmPredict` for validation
-* `hyperband` is hard-coded to use the number of iterations as a resource
+* `hyperband` is hard-coded to use the number of iterations as a resource
* `hyperband` can only optimize continuous hyperparameters
### Usage
@@ -542,7 +543,7 @@ B = img_crop(img_in = A, w = 20, h = 10, x_offset = 0, y_offset = 0)
## `img_mirror`-Function
The `img_mirror`-function is an image data augumentation function.
-It flips an image on the `X` (horizontal) or `Y` (vertical) axis.
+It flips an image on the `X` (horizontal) or `Y` (vertical) axis.
### Usage
@@ -679,7 +680,7 @@ kmeans(X = X, k = 20, runs = 10, max_iter = 5000, eps = 0.000001, is_verbose = F
## `lm`-Function
The `lm`-function solves linear regression using either the **direct solve method** or the **conjugate gradient algorithm**
-depending on the input size of the matrices (See [`lmDS`-function](#lmds-function) and
+depending on the input size of the matrices (See [`lmDS`-function](#lmds-function) and
[`lmCG`-function](#lmcg-function) respectively).
### Usage
@@ -711,10 +712,10 @@ is called internally and parameters `tol` and `maxi` are ignored.
##### `icpt`-Argument
The *icpt-argument* can be set to 3 modes:
-
- * 0 = no intercept, no shifting, no rescaling
- * 1 = add intercept, but neither shift nor rescale X
- * 2 = add intercept, shift & rescale X columns to mean = 0, variance = 1
+
+* 0 = no intercept, no shifting, no rescaling
+* 1 = add intercept, but neither shift nor rescale X
+* 2 = add intercept, shift & rescale X columns to mean = 0, variance = 1
### Example
@@ -999,7 +1000,7 @@ Y= scale(X,center,scale)
Implements training phase of Sherlock: A Deep Learning Approach to Semantic Data Type Detection
-[Hulsebos, Madelon, et al. "Sherlock: A deep learning approach to semantic data type detection."
+[Hulsebos, Madelon, et al. "Sherlock: A deep learning approach to semantic data type detection."
Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining., 2019]
### Usage
@@ -1055,7 +1056,7 @@ write(label_encoding, "weights/label_encoding")
Implements prediction and evaluation phase of Sherlock: A Deep Learning Approach to Semantic Data Type Detection
-[Hulsebos, Madelon, et al. "Sherlock: A deep learning approach to semantic data type detection."
+[Hulsebos, Madelon, et al. "Sherlock: A deep learning approach to semantic data type detection."
Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining., 2019]
### Usage
@@ -1085,7 +1086,7 @@ sherlockPredict(X, cW1, cb1, cW2, cb2, cW3, cb3, wW1, wb1, wW2, wb2, wW3, wb3,
| Type | Description |
| :------------- | :---------- |
| Matrix[Double] | Class probabilities of shape (N, K). |
-### Example
+### Example
```r
# preprocessed validation data taken from sherlock corpus
@@ -1116,7 +1117,7 @@ fW3, fb3)
## `sigmoid`-Function
-The Sigmoid function is a type of activation function, and also defined as a squashing function which limit the output
+The Sigmoid function is a type of activation function, and also defined as a squashing function which limit the output
to a range between 0 and 1, which will make these functions useful in the prediction of probabilities.
### Usage
@@ -1149,9 +1150,9 @@ Y = sigmoid(X)
The `smote`-function (Synthetic Minority Oversampling Technique) implements a classical techniques for handling class imbalance.
The built-in takes the samples from minority class and over-sample them by generating the synthesized samples.
The built-in accepts two parameters s and k. The parameter s define the number of synthesized samples to be generated
- i.e., over-sample the minority class by s time, where s is the multiple of 100. For given 40 samples of minority class and
- s = 200 the smote will generate the 80 synthesized samples to over-sample the class by 200 percent. The parameter k is used to generate the
- k nearest neighbours for each minority class sample and then the neighbours are chosen randomly in synthesis process.
+i.e., over-sample the minority class by s time, where s is the multiple of 100. For given 40 samples of minority class and
+s = 200 the smote will generate the 80 synthesized samples to over-sample the class by 200 percent. The parameter k is used to generate the
+k nearest neighbours for each minority class sample and then the neighbours are chosen randomly in synthesis process.
### Usage
@@ -1172,7 +1173,7 @@ smote(X, s, k, verbose);
| Type | Description |
| :------------- | :---------- |
-| Matrix[Double] | Matrix of (N/100) * X synthetic minority class samples
+| Matrix[Double] | Matrix of (N/100) * X synthetic minority class samples
### Example
@@ -1184,7 +1185,7 @@ B = smote(X = X, s=200, k=3, verbose=TRUE);
## `steplm`-Function
The `steplm`-function (stepwise linear regression) implements a classical forward feature selection method.
-This method iteratively runs what-if scenarios and greedily selects the next best feature until the Akaike
+This method iteratively runs what-if scenarios and greedily selects the next best feature until the Akaike
information criterion (AIC) does not improve anymore. Each configuration trains a regression model via `lm`,
which in turn calls either the closed form `lmDS` or iterative `lmGC`.
@@ -1216,13 +1217,13 @@ steplm(X, y, icpt);
##### `icpt`-Argument
The *icpt-arg* can be set to 2 modes:
-
- * 0 = no intercept, no shifting, no rescaling
- * 1 = add intercept, but neither shift nor rescale X
+
+* 0 = no intercept, no shifting, no rescaling
+* 1 = add intercept, but neither shift nor rescale X
##### `selected`-Output
-If the best AIC is achieved without any features the matrix of *selected* features contains 0. Moreover, in this case no further statistics will be produced
+If the best AIC is achieved without any features the matrix of *selected* features contains 0. Moreover, in this case no further statistics will be produced
### Example
@@ -1271,7 +1272,7 @@ ress = slicefinder(X = X,W = w, Y = y, k = 5, paq = 1, S = 2);
## `normalize`-Function
The `normalize`-function normalises the values of a matrix by changing the dataset to use a common scale.
-This is done while preserving differences in the ranges of values.
+This is done while preserving differences in the ranges of values.
The output is a matrix of values in range [0,1].
### Usage
@@ -1341,14 +1342,14 @@ H = rand(rows = 2, cols = ncol(X), min = -0.05, max = 0.05);
gnmf(X = X, rnk = 2, eps = 10^-8, maxi = 10)
```
-## `naivebayes`-Function
+## `naiveBayes`-Function
-The `naivebayes`-function computes the class conditional probabilities and class priors.
+The `naiveBayes`-function computes the class conditional probabilities and class priors.
### Usage
```r
-naivebayes(D, C, laplace, verbose)
+naiveBayes(D, C, laplace, verbose)
```
### Arguments
@@ -1372,7 +1373,38 @@ naivebayes(D, C, laplace, verbose)
```r
D=rand(rows=10,cols=1,min=10)
C=rand(rows=10,cols=1,min=10)
-[prior, classConditionals] = naivebayes(D, C, laplace = 1, verbose = TRUE)
+[prior, classConditionals] = naiveBayes(D, C, laplace = 1, verbose = TRUE)
+```
+
+## `naiveBayesPredict`-Function
+
+The `naiveBayesPredict`-function scores test data with a trained naive Bayes model and returns the predicted class for each row.
+
+### Usage
+
+```r
+naiveBayesPredict(X=X, P=P, C=C)
+```
+
+### Arguments
+
+| Name | Type | Default | Description |
+| :------ | :------------- | -------- | :---------- |
+| X | Matrix[Double] | required | Matrix of test data with N rows and one column per feature. |
+| P | Matrix[Double] | required | Class priors, a column vector with one row per class. |
+| C | Matrix[Double] | required | Class conditional probabilities, a matrix with one row per class and one column per feature. |
+
+### Returns
+
+| Type | Description |
+| :------------- | :---------- |
+| Matrix[Double] | `YRaw`: matrix of raw (unnormalized) log scores with one row per test row and one column per class. |
+| Matrix[Double] | `Y`: column vector with the index of the highest-scoring class for each test row. |
+
+### Example
+
+```r
+[YRaw, Y] = naiveBayesPredict(X=data, P=model_prior, C=model_conditionals)
```
## `outlier`-Function
@@ -1439,7 +1471,7 @@ y = toOneHot(X,numClasses)
## `mdedup`-Function
-The `mdedup`-function implements builtin for deduplication using matching dependencies
+The `mdedup`-function implements builtin for deduplication using matching dependencies
(e.g. Street 0.95, City 0.90 -> ZIP 1.0) by Jaccard distance.
### Usage
@@ -1553,7 +1585,7 @@ Y = winsorize(X=X)
## `gmm`-Function
-The `gmm`-function implements builtin Gaussian Mixture Model with four different types of
+The `gmm`-function implements builtin Gaussian Mixture Model with four different types of
covariance matrices i.e., VVV, EEE, VVI, VII and two initialization methods namely "kmeans" and "random".
### Usage
diff --git a/scripts/builtin/naivebayes.dml b/scripts/builtin/naiveBayes.dml
similarity index 96%
rename from scripts/builtin/naivebayes.dml
rename to scripts/builtin/naiveBayes.dml
index 1be291a..313ec09 100644
--- a/scripts/builtin/naivebayes.dml
+++ b/scripts/builtin/naiveBayes.dml
@@ -19,7 +19,7 @@
#
#-------------------------------------------------------------
-m_naivebayes = function(Matrix[Double] D, Matrix[Double] C, Double laplace = 1, Boolean verbose = TRUE)
+m_naiveBayes = function(Matrix[Double] D, Matrix[Double] C, Double laplace = 1, Boolean verbose = TRUE)
return (Matrix[Double] prior, Matrix[Double] classConditionals)
{
laplaceCorrection = laplace;
diff --git a/src/test/scripts/functions/builtin/NaiveBayes.dml b/scripts/builtin/naiveBayesPredict.dml
similarity index 74%
copy from src/test/scripts/functions/builtin/NaiveBayes.dml
copy to scripts/builtin/naiveBayesPredict.dml
index 1af2672..8ae938e 100644
--- a/src/test/scripts/functions/builtin/NaiveBayes.dml
+++ b/scripts/builtin/naiveBayesPredict.dml
@@ -19,8 +19,14 @@
#
#-------------------------------------------------------------
-X = read($1);
-y = read($2);
-[prior, conditionals] = naivebayes(D=X, C=y, laplace=$4);
-write(prior, $5);
-write(conditionals, $6);
+m_naiveBayesPredict = function(Matrix[Double] X, Matrix[Double] P, Matrix[Double] C)
+ return (Matrix[Double] YRaw, Matrix[Double] Y)
+{
+ numRows = nrow(X)
+ model = cbind(C, P)
+
+ ones = matrix(1, rows=numRows, cols=1);
+ X_w_ones = cbind(X, ones);
+ YRaw = X_w_ones %*% t(log(model));
+ Y = rowIndexMax(YRaw);
+}
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index 897aadc..27a0d73 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -181,7 +181,8 @@ public enum Builtins {
NCOL("ncol", false),
NORMALIZE("normalize", true),
NROW("nrow", false),
- NAIVEBAYES("naivebayes", true, false),
+ NAIVEBAYES("naiveBayes", true, false),
+ NAIVEBAYESPREDICT("naiveBayesPredict", true, false),
OUTER("outer", false),
OUTLIER("outlier", true, false), //TODO parameterize opposite
OUTLIER_SD("outlierBySd", true),
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinNaiveBayesPredictTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinNaiveBayesPredictTest.java
new file mode 100644
index 0000000..9318f45
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinNaiveBayesPredictTest.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+
+public class BuiltinNaiveBayesPredictTest extends AutomatedTestBase {
+ private final static String TEST_NAME = "NaiveBayesPredict";
+ private final static String TEST_DIR = "functions/builtin/";
+ private final static String TEST_CLASS_DIR = TEST_DIR + BuiltinNaiveBayesPredictTest.class.getSimpleName() + "/";
+ private final static int numClasses = 10;
+
+ public double eps = 1e-7;
+
+ @Override public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"YRaw", "Y"}));
+ }
+
+ @Test public void testSmallDense() {
+ testNaiveBayesPredict(100, 50, 0.7);
+ }
+
+ @Test public void testLargeDense() {
+ testNaiveBayesPredict(10000, 750, 0.7);
+ }
+
+ @Test public void testSmallSparse() {
+ testNaiveBayesPredict(100, 50, 0.01);
+ }
+
+ @Test public void testLargeSparse() {
+ testNaiveBayesPredict(10000, 750, 0.01);
+ }
+
+ public void testNaiveBayesPredict(int rows, int cols, double sparsity) {
+ loadTestConfiguration(getTestConfiguration(TEST_NAME));
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+
+ int classes = numClasses;
+ double laplace = 1;
+
+ List<String> proArgs = new ArrayList<>();
+ proArgs.add("-args");
+ proArgs.add(input("D"));
+ proArgs.add(input("C"));
+ proArgs.add(String.valueOf(classes));
+ proArgs.add(String.valueOf(laplace));
+ proArgs.add(output("YRaw"));
+ proArgs.add(output("Y"));
+ programArgs = proArgs.toArray(new String[proArgs.size()]);
+
+ rCmd = getRCmd(inputDir(), Integer.toString(classes), Double.toString(laplace), expectedDir());
+
+ double[][] D = getRandomMatrix(rows, cols, 0, 1, sparsity, -1);
+ double[][] C = getRandomMatrix(rows, 1, 0, 1, 1, -1);
+ for(int i = 0; i < rows; i++) {
+ C[i][0] = (int) (C[i][0] * classes) + 1;
+ C[i][0] = (C[i][0] > classes) ? classes : C[i][0];
+ }
+
+ writeInputMatrixWithMTD("D", D, true);
+ writeInputMatrixWithMTD("C", C, true);
+
+ runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+
+ runRScript(true);
+
+ HashMap<CellIndex, Double> YRawR = readRMatrixFromExpectedDir("YRaw");
+ HashMap<CellIndex, Double> YR = readRMatrixFromExpectedDir("Y");
+ HashMap<CellIndex, Double> YRawSYSTEMDS = readDMLMatrixFromOutputDir("YRaw");
+ HashMap<CellIndex, Double> YSYSTEMDS = readDMLMatrixFromOutputDir("Y");
+ TestUtils.compareMatrices(YRawR, YRawSYSTEMDS, eps, "YRawR", "YRawSYSTEMDS");
+ TestUtils.compareMatrices(YR, YSYSTEMDS, eps, "YR", "YSYSTEMDS");
+ }
+}
diff --git a/src/test/scripts/functions/builtin/NaiveBayes.dml b/src/test/scripts/functions/builtin/NaiveBayes.dml
index 1af2672..65303d3 100644
--- a/src/test/scripts/functions/builtin/NaiveBayes.dml
+++ b/src/test/scripts/functions/builtin/NaiveBayes.dml
@@ -21,6 +21,6 @@
X = read($1);
y = read($2);
-[prior, conditionals] = naivebayes(D=X, C=y, laplace=$4);
+[prior, conditionals] = naiveBayes(D=X, C=y, laplace=$4);
write(prior, $5);
write(conditionals, $6);
diff --git a/src/test/scripts/functions/builtin/NaiveBayesPredict.R b/src/test/scripts/functions/builtin/NaiveBayesPredict.R
new file mode 100644
index 0000000..18fcb9e
--- /dev/null
+++ b/src/test/scripts/functions/builtin/NaiveBayesPredict.R
@@ -0,0 +1,66 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+
+library("Matrix")
+library("naivebayes")
+
+D = as.matrix(readMM(paste(args[1], "D.mtx", sep="")))
+C = as.matrix(readMM(paste(args[1], "C.mtx", sep="")))
+laplace <- as.numeric(args[3])
+
+# divide D into "train" and "test" data
+numRows = nrow(D)
+trainSize = numRows * 0.8
+
+trainData = D[1:trainSize, ]
+testData = D[(trainSize+1):numRows, ]
+y <- factor(C[1:trainSize])
+
+# Naive Bayes prediction needs unique column names
+features <- paste0("V", seq_len(ncol(trainData)))
+colnames(trainData) <- features
+colnames(testData) <- features
+
+# Create the model based on the training data
+model <- multinomial_naive_bayes(x = trainData, y = y, laplace = laplace)
+
+# The SystemDS DML script returns raw scores (YRaw), whereas the
+# predict function of the R "naivebayes" package returns a
+# probabilities matrix.
+# Example: YRaw <- predict(model, newdata = testData, type = "prob")
+
+# We need to return "Raw" values
+lev <- model$levels
+prior <- model$prior
+params <- t(model$params)
+YRaw <- tcrossprod(testData, log(params))
+
+for (ith_class in seq_along(lev)) {
+ YRaw[ ,ith_class] <- YRaw[ ,ith_class] + log(prior[ith_class])
+}
+
+Y <- max.col(YRaw, ties.method="last")
+
+# write out the predictions
+writeMM(as(YRaw, "CsparseMatrix"), paste(args[4], "YRaw", sep=""))
+writeMM(as(Y, "CsparseMatrix"), paste(args[4], "Y", sep=""))
diff --git a/src/test/scripts/functions/builtin/NaiveBayes.dml b/src/test/scripts/functions/builtin/NaiveBayesPredict.dml
similarity index 64%
copy from src/test/scripts/functions/builtin/NaiveBayes.dml
copy to src/test/scripts/functions/builtin/NaiveBayesPredict.dml
index 1af2672..4943827 100644
--- a/src/test/scripts/functions/builtin/NaiveBayes.dml
+++ b/src/test/scripts/functions/builtin/NaiveBayesPredict.dml
@@ -19,8 +19,22 @@
#
#-------------------------------------------------------------
-X = read($1);
-y = read($2);
-[prior, conditionals] = naivebayes(D=X, C=y, laplace=$4);
-write(prior, $5);
-write(conditionals, $6);
+D = read($1);
+C = read($2);
+
+# divide data into "train" and "test" subsets
+numRows = nrow(D);
+trainSize = numRows * 0.8;
+trainData = D[1:trainSize,];
+testData = D[(trainSize+1):numRows,];
+C = C[1:trainSize,];
+
+# compute "prior" and "conditionals" with the naiveBayes built-in function
+[prior, conditionals] = naiveBayes(D=trainData, C=C, laplace=$4, verbose=FALSE);
+
+# compute predictions
+[YRaw,Y] = naiveBayesPredict(X=testData, P=prior, C=conditionals);
+
+# write the results
+write(YRaw, $5);
+write(Y, $6);
diff --git a/src/test/scripts/installDependencies.R b/src/test/scripts/installDependencies.R
index 8482f1e..46d6f65 100644
--- a/src/test/scripts/installDependencies.R
+++ b/src/test/scripts/installDependencies.R
@@ -61,6 +61,7 @@ custom_install("imputeTS");
custom_install("FNN");
custom_install("class");
custom_install("unbalanced");
+custom_install("naivebayes");
print("Installation Done")
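
Note on the test setup in this commit: BuiltinNaiveBayesPredictTest.java turns uniform random values in [0, 1] into class labels 1..classes, and both NaiveBayesPredict.dml and NaiveBayesPredict.R train on the first 80% of the rows and predict on the rest. The short Java sketch below reproduces only that arithmetic. The class and method names are hypothetical, and the integer cast in `trainSize` is an assumption made for this sketch (the scripts themselves use `numRows * 0.8` directly as an index bound).

```java
// Hypothetical illustration only; not part of the SystemDS test suite.
final class NaiveBayesTestDataSketch {

    // Maps u in [0, 1] to a label in 1..classes; u == 1.0 is clamped to classes.
    static int toLabel(double u, int classes) {
        int label = (int) (u * classes) + 1;
        return Math.min(label, classes);
    }

    // Number of training rows for the 80/20 split used by the test scripts.
    static int trainSize(int rows) {
        return (int) (rows * 0.8);
    }

    public static void main(String[] args) {
        int rows = 100;
        int train = trainSize(rows);
        // Rows 1..train are used for training, rows train+1..rows for prediction.
        System.out.println("train rows: 1.." + train + ", test rows: " + (train + 1) + ".." + rows);
        System.out.println("label for u=0.95 with 10 classes: " + toLabel(0.95, 10));
    }
}
```

The clamp matters only for the boundary value u = 1.0, which would otherwise produce the out-of-range label classes + 1.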